[Mlir-commits] [mlir] [mlir]Fix compose subview (PR #80551)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Tue Feb 6 11:58:50 PST 2024


================
@@ -51,67 +50,83 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
     }
 
     // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
-    SmallVector<OpFoldResult> offsets, sizes, strides;
-
-    // Because we only support input strides of 1, the output stride is also
-    // always 1.
-    if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
-          Attribute attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr);
-          return attr && cast<IntegerAttr>(attr).getInt() == 1;
-        })) {
-      strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
-                                          rewriter.getI64IntegerAttr(1));
-    } else {
-      return failure();
+    SmallVector<OpFoldResult> offsets, sizes, strides,
+        opStrides = op.getMixedStrides(),
+        sourceStrides = sourceOp.getMixedStrides();
+
+    // The output stride in each dimension is equal to the product of the
+    // strides of source and op in that dimension.
+    for (auto &&[opStride, sourceStride] :
+         llvm::zip(opStrides, sourceStrides)) {
+      Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride);
+      Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride);
+      if (!opStrideAttr || !sourceStrideAttr)
+        return failure();
+      strides.push_back(rewriter.getI64IntegerAttr(
+          cast<IntegerAttr>(opStrideAttr).getInt() *
+          cast<IntegerAttr>(sourceStrideAttr).getInt()));
----------------
ftynse wrote:

Nit: let's just cast once and extract the value instead of repeating these casts several times below.

https://github.com/llvm/llvm-project/pull/80551


More information about the Mlir-commits mailing list