[Mlir-commits] [mlir] [mlir]Fix compose subview (PR #80551)

lonely eagle llvmlistbot at llvm.org
Tue Feb 6 11:44:25 PST 2024


https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/80551

>From 64ad5c642a753eeec01d0054d233a83cfeab50b0 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <lonelyeagle.02 at bytedance.com>
Date: Sun, 4 Feb 2024 00:06:46 +0800
Subject: [PATCH 1/3] fix compose-subview.

---
 .../MemRef/Transforms/ComposeSubView.cpp      | 89 +++++++++++--------
 mlir/test/Transforms/compose-subview.mlir     | 48 ++++++++++
 2 files changed, 100 insertions(+), 37 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index 431d270b0a2cb..dd10865544a50 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -24,14 +24,14 @@ using namespace mlir;
 
 namespace {
 
-// Replaces a subview of a subview with a single subview. Only supports subview
-// ops with static sizes and static strides of 1 (both static and dynamic
+// Replaces a subview of a subview with a single subview (static and dynamic
 // offsets are supported).
 struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(memref::SubViewOp op,
                                 PatternRewriter &rewriter) const override {
+
     // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
     // produces the input of the op we're rewriting (for 'SubViewOp' the input
     // is called the "source" value). We can only combine them if both 'op' and
@@ -52,66 +52,81 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
 
     // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
     SmallVector<OpFoldResult> offsets, sizes, strides;
-
-    // Because we only support input strides of 1, the output stride is also
-    // always 1.
-    if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
-          Attribute attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr);
-          return attr && cast<IntegerAttr>(attr).getInt() == 1;
-        })) {
-      strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
-                                          rewriter.getI64IntegerAttr(1));
-    } else {
-      return failure();
+    auto opStrides = op.getMixedStrides();
+    auto sourceStrides = sourceOp.getMixedStrides();
+
+    // The output stride in each dimension is equal to the product of the
+    // dimensions corresponding to source and op.
+    for (auto [opStride, sourceStride] : llvm::zip(opStrides, sourceStrides)) {
+      Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride);
+      Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride);
+      if (!opStrideAttr || !sourceStrideAttr)
+        return failure();
+      strides.push_back(rewriter.getI64IntegerAttr(
+          cast<IntegerAttr>(opStrideAttr).getInt() *
+          cast<IntegerAttr>(sourceStrideAttr).getInt()));
     }
 
     // The rules for calculating the new offsets and sizes are:
     // * Multiple subview offsets for a given dimension compose additively.
-    //   ("Offset by m" followed by "Offset by n" == "Offset by m + n")
+    //   ("Offset by m and Stride by k" followed by "Offset by n" == "Offset by
+    //   m + n * k")
     // * Multiple sizes for a given dimension compose by taking the size of the
     //   final subview and ignoring the rest. ("Take m values" followed by "Take
     //   n values" == "Take n values") This size must also be the smallest one
     //   by definition (a subview needs to be the same size as or smaller than
     //   its source along each dimension; presumably subviews that are larger
     //   than their sources are disallowed by validation).
-    for (auto it : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
-                             op.getMixedSizes())) {
-      auto opOffset = std::get<0>(it);
-      auto sourceOffset = std::get<1>(it);
-      auto opSize = std::get<2>(it);
-
+    for (auto [opOffset, sourceOffset, sourceStride, opSize] :
+         llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
+                   sourceOp.getMixedStrides(), op.getMixedSizes())) {
       // We only support static sizes.
       if (opSize.is<Value>()) {
         return failure();
       }
-
       sizes.push_back(opSize);
       Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
                 sourceOffsetAttr =
-                    llvm::dyn_cast_if_present<Attribute>(sourceOffset);
-
+                    llvm::dyn_cast_if_present<Attribute>(sourceOffset),
+                sourceStrideAttr =
+                    llvm::dyn_cast_if_present<Attribute>(sourceStride);
       if (opOffsetAttr && sourceOffsetAttr) {
+
         // If both offsets are static we can simply calculate the combined
         // offset statically.
         offsets.push_back(rewriter.getI64IntegerAttr(
-            cast<IntegerAttr>(opOffsetAttr).getInt() +
+            cast<IntegerAttr>(opOffsetAttr).getInt() *
+                cast<IntegerAttr>(sourceStrideAttr).getInt() +
             cast<IntegerAttr>(sourceOffsetAttr).getInt()));
       } else {
-        // When either offset is dynamic, we must emit an additional affine
-        // transformation to add the two offsets together dynamically.
-        AffineExpr expr = rewriter.getAffineConstantExpr(0);
+        AffineExpr expr0 = rewriter.getAffineConstantExpr(0);
+        AffineExpr expr1 = rewriter.getAffineConstantExpr(0);
         SmallVector<Value> affineApplyOperands;
-        for (auto valueOrAttr : {opOffset, sourceOffset}) {
-          if (auto attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr)) {
-            expr = expr + cast<IntegerAttr>(attr).getInt();
+        SmallVector<OpFoldResult> opOffsets{sourceOffset, opOffset};
+        for (auto [idx, offset] : llvm::enumerate(opOffsets)) {
+          if (auto attr = llvm::dyn_cast_if_present<Attribute>(offset)) {
+            if (idx == 0) {
+              expr0 = expr0 + cast<IntegerAttr>(attr).getInt();
+            } else if (idx == 1) {
+              expr1 = expr1 + cast<IntegerAttr>(attr).getInt();
+              expr1 = expr1 * cast<IntegerAttr>(sourceStrideAttr).getInt();
+              expr0 = expr0 + expr1;
+            }
           } else {
-            expr =
-                expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
-            affineApplyOperands.push_back(valueOrAttr.get<Value>());
+            if (idx == 0) {
+              expr0 = expr0 +
+                      rewriter.getAffineSymbolExpr(affineApplyOperands.size());
+              affineApplyOperands.push_back(offset.get<Value>());
+            } else if (idx == 1) {
+              expr1 = expr1 +
+                      rewriter.getAffineSymbolExpr(affineApplyOperands.size());
+              affineApplyOperands.push_back(offset.get<Value>());
+              expr1 = expr1 * cast<IntegerAttr>(sourceStrideAttr).getInt();
+              expr0 = expr0 + expr1;
+            }
           }
         }
-
-        AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
+        AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr0);
         Value result = rewriter.create<affine::AffineApplyOp>(
             op.getLoc(), map, affineApplyOperands);
         offsets.push_back(result);
@@ -120,8 +135,8 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
 
     // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
     // uses it can be removed by a (separate) dead code elimination pass.
-    rewriter.replaceOpWithNewOp<memref::SubViewOp>(op, sourceOp.getSource(),
-                                                   offsets, sizes, strides);
+    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
+        op, op.getType(), sourceOp.getSource(), offsets, sizes, strides);
     return success();
   }
 };
diff --git a/mlir/test/Transforms/compose-subview.mlir b/mlir/test/Transforms/compose-subview.mlir
index cb8ebcb2bf9e9..2fa551c218d42 100644
--- a/mlir/test/Transforms/compose-subview.mlir
+++ b/mlir/test/Transforms/compose-subview.mlir
@@ -45,3 +45,51 @@ func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024,
   %1 = memref.subview %0[1, %cst_128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
   return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>>
 }
+
+// -----
+
+func.func @main(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
+  //      CHECK: subview %arg0[4, 384] [1, 64] [4, 4] 
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
+  %0 = memref.subview %input[2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: 2304>>
+  %1 = memref.subview %0[1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: 2304>> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
+  return %1 : memref<1x64xf32, strided<[4096, 4], offset: 4480>>
+}
+
+// -----
+
+func.func @main(%input: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
+  //      CHECK: subview %arg0[7, 7] [2, 2] [8, 8]
+  // CHECK-SAME: memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>>
+  %0 = memref.subview %input[1, 1] [12, 12] [2, 2] : memref<30x30xf32> to memref<12x12xf32, strided<[60, 2], offset: 31>>
+  %1 = memref.subview %0[1, 1] [5, 5] [2, 2] : memref<12x12xf32, strided<[60, 2], offset: 31>> to memref<5x5xf32, strided<[120, 4], offset: 93>>
+  %2 = memref.subview %1[1, 1] [2, 2] [2, 2] : memref<5x5xf32, strided<[120, 4], offset: 93>> to memref<2x2xf32, strided<[240, 8], offset: 217>>
+  return %2 : memref<2x2xf32, strided<[240, 8], offset: 217>> 
+}
+
+// -----
+
+func.func @main(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+  //      CHECK:%[[VAL_1:.*]] = arith.constant 4 : index
+  %cst_2 = arith.constant 2 : index
+  //      CHECK:%[[VAL_2:.*]] = arith.constant 384 : index
+  %cst_64 = arith.constant 64 : index
+  //      CHECK: subview %arg0{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 64] [4, 4]
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+  %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>>
+  %1 = memref.subview %0[1, %cst_64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+  return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>>
+}
+
+// -----
+
+func.func @main(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+  //      CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
+  %cst_1 = arith.constant 1 : index
+  %cst_2 = arith.constant 2 : index
+  //      CHECK: subview %arg0{{\[}}%[[VAL_1]], 384] [1, 64] [4, 4]
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+  %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>>
+  %1 = memref.subview %0[%cst_1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+  return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>>
+}

>From 1632e3a12c2c3e06c78877865f4cba32164dd23b Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <lonelyeagle.02 at bytedance.com>
Date: Tue, 6 Feb 2024 11:00:29 +0800
Subject: [PATCH 2/3] fix nit and update test.

---
 .../MemRef/Transforms/ComposeSubView.cpp      | 18 ++---
 mlir/test/Transforms/compose-subview.mlir     | 68 +++++++++++--------
 2 files changed, 47 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index dd10865544a50..47c6d01f3fda2 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -31,7 +31,6 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
 
   LogicalResult matchAndRewrite(memref::SubViewOp op,
                                 PatternRewriter &rewriter) const override {
-
     // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
     // produces the input of the op we're rewriting (for 'SubViewOp' the input
     // is called the "source" value). We can only combine them if both 'op' and
@@ -51,13 +50,14 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
     }
 
     // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
-    SmallVector<OpFoldResult> offsets, sizes, strides;
-    auto opStrides = op.getMixedStrides();
-    auto sourceStrides = sourceOp.getMixedStrides();
+    SmallVector<OpFoldResult> offsets, sizes, strides,
+        opStrides = op.getMixedStrides(),
+        sourceStrides = sourceOp.getMixedStrides();
 
     // The output stride in each dimension is equal to the product of the
     // dimensions corresponding to source and op.
-    for (auto [opStride, sourceStride] : llvm::zip(opStrides, sourceStrides)) {
+    for (auto &&[opStride, sourceStride] :
+         llvm::zip(opStrides, sourceStrides)) {
       Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride);
       Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride);
       if (!opStrideAttr || !sourceStrideAttr)
@@ -77,7 +77,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
     //   by definition (a subview needs to be the same size as or smaller than
     //   its source along each dimension; presumably subviews that are larger
     //   than their sources are disallowed by validation).
-    for (auto [opOffset, sourceOffset, sourceStride, opSize] :
+    for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
          llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
                    sourceOp.getMixedStrides(), op.getMixedSizes())) {
       // We only support static sizes.
@@ -103,13 +103,13 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
         AffineExpr expr1 = rewriter.getAffineConstantExpr(0);
         SmallVector<Value> affineApplyOperands;
         SmallVector<OpFoldResult> opOffsets{sourceOffset, opOffset};
-        for (auto [idx, offset] : llvm::enumerate(opOffsets)) {
+        for (auto &&[idx, offset] : llvm::enumerate(opOffsets)) {
           if (auto attr = llvm::dyn_cast_if_present<Attribute>(offset)) {
             if (idx == 0) {
               expr0 = expr0 + cast<IntegerAttr>(attr).getInt();
             } else if (idx == 1) {
-              expr1 = expr1 + cast<IntegerAttr>(attr).getInt();
-              expr1 = expr1 * cast<IntegerAttr>(sourceStrideAttr).getInt();
+              expr1 = expr1 + cast<IntegerAttr>(attr).getInt() *
+                                  cast<IntegerAttr>(sourceStrideAttr).getInt();
               expr0 = expr0 + expr1;
             }
           } else {
diff --git a/mlir/test/Transforms/compose-subview.mlir b/mlir/test/Transforms/compose-subview.mlir
index 2fa551c218d42..22ffd836c68ed 100644
--- a/mlir/test/Transforms/compose-subview.mlir
+++ b/mlir/test/Transforms/compose-subview.mlir
@@ -1,8 +1,9 @@
 // RUN: mlir-opt %s -test-compose-subview -split-input-file | FileCheck %s
 
-func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
-  //      CHECK: subview %arg0[3, 384] [1, 128] [1, 1]
-  // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME:  %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
+  // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
   %0 = memref.subview %input[2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: 2304>>
   %1 = memref.subview %0[1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: 2304>> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
   return %1 : memref<1x128xf32, strided<[1024, 1], offset: 3456>>
@@ -10,9 +11,10 @@ func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024,
 
 // -----
 
-func.func @main(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
-  //      CHECK: subview %arg0[3, 673] [1, 10] [1, 1]
-  // CHECK-SAME: memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME:  %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
+  // CHECK:     %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 673] [1, 10] [1, 1] : memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
   %0 = memref.subview %input[1, 512] [3, 256] [1, 1] : memref<4x1024xf32> to memref<3x256xf32, strided<[1024, 1], offset: 1536>>
   %1 = memref.subview %0[1, 128] [2, 128] [1, 1] : memref<3x256xf32, strided<[1024, 1], offset: 1536>> to memref<2x128xf32, strided<[1024, 1], offset: 2688>>
   %2 = memref.subview %1[1, 33] [1, 10] [1, 1] : memref<2x128xf32, strided<[1024, 1], offset: 2688>> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
@@ -21,12 +23,13 @@ func.func @main(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1
 
 // -----
 
-func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
-  // CHECK: [[CST_3:%.*]] = arith.constant 3 : index
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME:  %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
+  // CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
   %cst_1 = arith.constant 1 : index
   %cst_2 = arith.constant 2 : index
-  //      CHECK: subview %arg0{{\[}}[[CST_3]], 384] [1, 128] [1, 1]
-  // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
+  // CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
   %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>>
   %1 = memref.subview %0[%cst_1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
   return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>>
@@ -34,13 +37,14 @@ func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024,
 
 // -----
 
-func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
-  // CHECK: [[CST_3:%.*]] = arith.constant 3 : index
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME:  %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
+  // CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
   %cst_2 = arith.constant 2 : index
-  // CHECK: [[CST_384:%.*]] = arith.constant 384 : index
+  // CHECK: %[[VAL_2:.*]] = arith.constant 384 : index
   %cst_128 = arith.constant 128 : index
-  //      CHECK: subview %arg0{{\[}}[[CST_3]], [[CST_384]]] [1, 128] [1, 1]
-  // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
+  // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
   %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>>
   %1 = memref.subview %0[1, %cst_128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
   return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>>
@@ -48,9 +52,10 @@ func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024,
 
 // -----
 
-func.func @main(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
-  //      CHECK: subview %arg0[4, 384] [1, 64] [4, 4] 
-  // CHECK-SAME: memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME:  %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
+  // CHECK:     %[[VAL_1:.*]] = memref.subview %[[VAL_0]][4, 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
   %0 = memref.subview %input[2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: 2304>>
   %1 = memref.subview %0[1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: 2304>> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
   return %1 : memref<1x64xf32, strided<[4096, 4], offset: 4480>>
@@ -58,9 +63,10 @@ func.func @main(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4
 
 // -----
 
-func.func @main(%input: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
-  //      CHECK: subview %arg0[7, 7] [2, 2] [8, 8]
-  // CHECK-SAME: memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>>
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME:  %[[VAL_0:.*]]: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
+func.func @subview_strided(%input: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
+  // CHECK:     %[[VAL_1:.*]] = memref.subview %[[VAL_0]][7, 7] [2, 2] [8, 8] : memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>>
   %0 = memref.subview %input[1, 1] [12, 12] [2, 2] : memref<30x30xf32> to memref<12x12xf32, strided<[60, 2], offset: 31>>
   %1 = memref.subview %0[1, 1] [5, 5] [2, 2] : memref<12x12xf32, strided<[60, 2], offset: 31>> to memref<5x5xf32, strided<[120, 4], offset: 93>>
   %2 = memref.subview %1[1, 1] [2, 2] [2, 2] : memref<5x5xf32, strided<[120, 4], offset: 93>> to memref<2x2xf32, strided<[240, 8], offset: 217>>
@@ -69,13 +75,14 @@ func.func @main(%input: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8],
 
 // -----
 
-func.func @main(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
-  //      CHECK:%[[VAL_1:.*]] = arith.constant 4 : index
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME:  %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+  // CHECK:     %[[VAL_1:.*]] = arith.constant 4 : index
   %cst_2 = arith.constant 2 : index
-  //      CHECK:%[[VAL_2:.*]] = arith.constant 384 : index
+  // CHECK:     %[[VAL_2:.*]] = arith.constant 384 : index
   %cst_64 = arith.constant 64 : index
-  //      CHECK: subview %arg0{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 64] [4, 4]
-  // CHECK-SAME: memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+  // CHECK:     %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
   %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>>
   %1 = memref.subview %0[1, %cst_64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
   return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>>
@@ -83,12 +90,13 @@ func.func @main(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4
 
 // -----
 
-func.func @main(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
-  //      CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
+// CHECK-LABEL: func.func @subview_strided(
+// CHECK-SAME:  %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+  // CHECK:     %[[VAL_1:.*]] = arith.constant 4 : index
   %cst_1 = arith.constant 1 : index
   %cst_2 = arith.constant 2 : index
-  //      CHECK: subview %arg0{{\[}}%[[VAL_1]], 384] [1, 64] [4, 4]
-  // CHECK-SAME: memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+  // CHECK:     %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
   %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>>
   %1 = memref.subview %0[%cst_1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
   return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>>

>From ac87e549355543d6cd151065df3719af36c1fc44 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <lonelyeagle.02 at bytedance.com>
Date: Wed, 7 Feb 2024 03:43:05 +0800
Subject: [PATCH 3/3] Eliminating unnecessary loops.

---
 .../MemRef/Transforms/ComposeSubView.cpp      | 46 +++++++++----------
 1 file changed, 23 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index 47c6d01f3fda2..aaa9e3740bc7c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -102,30 +102,30 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
         AffineExpr expr0 = rewriter.getAffineConstantExpr(0);
         AffineExpr expr1 = rewriter.getAffineConstantExpr(0);
         SmallVector<Value> affineApplyOperands;
-        SmallVector<OpFoldResult> opOffsets{sourceOffset, opOffset};
-        for (auto &&[idx, offset] : llvm::enumerate(opOffsets)) {
-          if (auto attr = llvm::dyn_cast_if_present<Attribute>(offset)) {
-            if (idx == 0) {
-              expr0 = expr0 + cast<IntegerAttr>(attr).getInt();
-            } else if (idx == 1) {
-              expr1 = expr1 + cast<IntegerAttr>(attr).getInt() *
-                                  cast<IntegerAttr>(sourceStrideAttr).getInt();
-              expr0 = expr0 + expr1;
-            }
-          } else {
-            if (idx == 0) {
-              expr0 = expr0 +
-                      rewriter.getAffineSymbolExpr(affineApplyOperands.size());
-              affineApplyOperands.push_back(offset.get<Value>());
-            } else if (idx == 1) {
-              expr1 = expr1 +
-                      rewriter.getAffineSymbolExpr(affineApplyOperands.size());
-              affineApplyOperands.push_back(offset.get<Value>());
-              expr1 = expr1 * cast<IntegerAttr>(sourceStrideAttr).getInt();
-              expr0 = expr0 + expr1;
-            }
-          }
+
+        // Add 'sourceOffset' to 'expr0'.
+        if (auto attr = llvm::dyn_cast_if_present<Attribute>(sourceOffset)) {
+          expr0 = expr0 + cast<IntegerAttr>(attr).getInt();
+        } else {
+          expr0 =
+              expr0 + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
+          affineApplyOperands.push_back(sourceOffset.get<Value>());
+        }
+
+        // Multiply 'opOffset' by 'sourceStride' and add the result to
+        // 'expr0'.
+        if (auto attr = llvm::dyn_cast_if_present<Attribute>(opOffset)) {
+          expr1 = expr1 + cast<IntegerAttr>(attr).getInt() *
+                              cast<IntegerAttr>(sourceStrideAttr).getInt();
+          expr0 = expr0 + expr1;
+        } else {
+          expr1 =
+              expr1 + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
+          affineApplyOperands.push_back(opOffset.get<Value>());
+          expr1 = expr1 * cast<IntegerAttr>(sourceStrideAttr).getInt();
+          expr0 = expr0 + expr1;
         }
+
         AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr0);
         Value result = rewriter.create<affine::AffineApplyOp>(
             op.getLoc(), map, affineApplyOperands);



More information about the Mlir-commits mailing list