[Mlir-commits] [mlir] [mlir]Fix compose subview (PR #80551)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Tue Feb 6 11:58:50 PST 2024


================
@@ -51,67 +50,83 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
     }
 
     // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
-    SmallVector<OpFoldResult> offsets, sizes, strides;
-
-    // Because we only support input strides of 1, the output stride is also
-    // always 1.
-    if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
-          Attribute attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr);
-          return attr && cast<IntegerAttr>(attr).getInt() == 1;
-        })) {
-      strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
-                                          rewriter.getI64IntegerAttr(1));
-    } else {
-      return failure();
+    SmallVector<OpFoldResult> offsets, sizes, strides,
+        opStrides = op.getMixedStrides(),
+        sourceStrides = sourceOp.getMixedStrides();
+
+    // The output stride in each dimension is equal to the product of the
+    // dimensions corresponding to source and op.
+    for (auto &&[opStride, sourceStride] :
+         llvm::zip(opStrides, sourceStrides)) {
+      Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride);
+      Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride);
+      if (!opStrideAttr || !sourceStrideAttr)
+        return failure();
+      strides.push_back(rewriter.getI64IntegerAttr(
+          cast<IntegerAttr>(opStrideAttr).getInt() *
+          cast<IntegerAttr>(sourceStrideAttr).getInt()));
     }
 
     // The rules for calculating the new offsets and sizes are:
     // * Multiple subview offsets for a given dimension compose additively.
-    //   ("Offset by m" followed by "Offset by n" == "Offset by m + n")
+    //   ("Offset by m and Stride by k" followed by "Offset by n" == "Offset by
+    //   m + n * k")
     // * Multiple sizes for a given dimension compose by taking the size of the
     //   final subview and ignoring the rest. ("Take m values" followed by "Take
     //   n values" == "Take n values") This size must also be the smallest one
     //   by definition (a subview needs to be the same size as or smaller than
     //   its source along each dimension; presumably subviews that are larger
     //   than their sources are disallowed by validation).
-    for (auto it : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
-                             op.getMixedSizes())) {
-      auto opOffset = std::get<0>(it);
-      auto sourceOffset = std::get<1>(it);
-      auto opSize = std::get<2>(it);
-
+    for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
+         llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
+                   sourceOp.getMixedStrides(), op.getMixedSizes())) {
       // We only support static sizes.
       if (opSize.is<Value>()) {
         return failure();
       }
-
       sizes.push_back(opSize);
       Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
                 sourceOffsetAttr =
-                    llvm::dyn_cast_if_present<Attribute>(sourceOffset);
-
+                    llvm::dyn_cast_if_present<Attribute>(sourceOffset),
+                sourceStrideAttr =
+                    llvm::dyn_cast_if_present<Attribute>(sourceStride);
       if (opOffsetAttr && sourceOffsetAttr) {
+
         // If both offsets are static we can simply calculate the combined
         // offset statically.
         offsets.push_back(rewriter.getI64IntegerAttr(
-            cast<IntegerAttr>(opOffsetAttr).getInt() +
+            cast<IntegerAttr>(opOffsetAttr).getInt() *
+                cast<IntegerAttr>(sourceStrideAttr).getInt() +
             cast<IntegerAttr>(sourceOffsetAttr).getInt()));
       } else {
-        // When either offset is dynamic, we must emit an additional affine
-        // transformation to add the two offsets together dynamically.
-        AffineExpr expr = rewriter.getAffineConstantExpr(0);
+        AffineExpr expr0 = rewriter.getAffineConstantExpr(0);
+        AffineExpr expr1 = rewriter.getAffineConstantExpr(0);
         SmallVector<Value> affineApplyOperands;
-        for (auto valueOrAttr : {opOffset, sourceOffset}) {
-          if (auto attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr)) {
-            expr = expr + cast<IntegerAttr>(attr).getInt();
-          } else {
-            expr =
-                expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
-            affineApplyOperands.push_back(valueOrAttr.get<Value>());
-          }
+
+        // Make 'expr0' add 'sourceOffset'.
+        if (auto attr = llvm::dyn_cast_if_present<Attribute>(sourceOffset)) {
+          expr0 = expr0 + cast<IntegerAttr>(attr).getInt();
+        } else {
+          expr0 =
+              expr0 + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
+          affineApplyOperands.push_back(sourceOffset.get<Value>());
+        }
----------------
ftynse wrote:

Okay, so now we don't need `expr0 = expr0 + X`; we can just do `expr0 = X` in both branches. A better name than `expr0` is also beneficial ;)

https://github.com/llvm/llvm-project/pull/80551


More information about the Mlir-commits mailing list