[Mlir-commits] [mlir] 3ba1b1c - Add a pattern to combine composed subview ops

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 1 10:57:15 PDT 2021


Author: Aden Grue
Date: 2021-04-01T10:56:57-07:00
New Revision: 3ba1b1cd201dbf2d5f9da4d3a018091c3e3a2d78

URL: https://github.com/llvm/llvm-project/commit/3ba1b1cd201dbf2d5f9da4d3a018091c3e3a2d78
DIFF: https://github.com/llvm/llvm-project/commit/3ba1b1cd201dbf2d5f9da4d3a018091c3e3a2d78.diff

LOG: Add a pattern to combine composed subview ops

Differential Revision: https://reviews.llvm.org/D99229

Added: 
    mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h
    mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp
    mlir/test/Transforms/compose-subview.mlir
    mlir/test/lib/Transforms/TestComposeSubView.cpp

Modified: 
    mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h
new file mode 100644
index 0000000000000..7a5ae3e8417b7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h
@@ -0,0 +1,28 @@
+//===- ComposeSubView.h - Combining composed subview ops --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Patterns for combining composed subview ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_
+#define MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_
+
+namespace mlir {
+
+// Forward declarations.
+class MLIRContext;
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
+
+// Adds to 'patterns' the pattern that replaces a subview of a subview with a
+// single subview.
+void populateComposeSubViewPatterns(OwningRewritePatternList &patterns,
+                                    MLIRContext *context);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
index b01ff954d7f3c..2fa0b96bed7ab 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRStandardOpsTransforms
   Bufferize.cpp
+  ComposeSubView.cpp
   DecomposeCallGraphTypes.cpp
   ExpandOps.cpp
   FuncBufferize.cpp

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp
new file mode 100644
index 0000000000000..cabaa614fa2a9
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp
@@ -0,0 +1,136 @@
+//===- ComposeSubView.cpp - Combining composed subview ops ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns for combining composed subview ops (i.e. subview
+// of a subview becomes a single subview).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/Transforms/ComposeSubView.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+
+namespace {
+
+// Replaces a subview of a subview with a single subview of the original
+// source. Only supports subview ops with static sizes and static strides of 1
+// (both static and dynamic offsets are supported).
+struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  // Returns success() after replacing 'op' with the combined subview, or
+  // failure() (before creating any new op) when 'op' and the op producing its
+  // source cannot be combined.
+  LogicalResult matchAndRewrite(memref::SubViewOp op,
+                                PatternRewriter &rewriter) const override {
+    // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
+    // produces the input of the op we're rewriting (for 'SubViewOp' the input
+    // is called the "source" value). We can only combine them if both 'op' and
+    // 'sourceOp' are 'SubViewOp'.
+    auto sourceOp = op.source().getDefiningOp<memref::SubViewOp>();
+    if (!sourceOp)
+      return failure();
+
+    // A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the
+    // output memref that are statically known to be equal to 1. We do not
+    // allow 'sourceOp' to be a rank-reducing subview because then our two
+    // 'SubViewOp's would have different numbers of offset/size/stride
+    // parameters (just difficult to deal with, not impossible if we end up
+    // needing it).
+    if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank())
+      return failure();
+
+    // The combined op is created below without an explicit result type, so
+    // also bail out if 'op' itself is rank-reducing; otherwise the result type
+    // of the replacement could differ from the type of 'op'.
+    if (op.getSourceType().getRank() != op.getType().getRank())
+      return failure();
+
+    // We only support strides that are statically known to be 1. This must be
+    // checked on the strides of the two ops being combined (checking the
+    // still-empty output list would be vacuously true).
+    auto isStaticOne = [](OpFoldResult valueOrAttr) {
+      Attribute attr = valueOrAttr.dyn_cast<Attribute>();
+      return attr && attr.cast<IntegerAttr>().getInt() == 1;
+    };
+    if (!llvm::all_of(op.getMixedStrides(), isStaticOne) ||
+        !llvm::all_of(sourceOp.getMixedStrides(), isStaticOne))
+      return failure();
+
+    // We only support static sizes. This is checked up front so that we never
+    // return failure() after an 'AffineApplyOp' has been created below.
+    if (llvm::any_of(op.getMixedSizes(), [](OpFoldResult size) {
+          return size.is<Value>();
+        }))
+      return failure();
+
+    // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
+    // Because we only support input strides of 1, the output stride is also
+    // always 1.
+    SmallVector<OpFoldResult> offsets, sizes;
+    SmallVector<OpFoldResult> strides(sourceOp.getMixedStrides().size(),
+                                      rewriter.getI64IntegerAttr(1));
+
+    // The rules for calculating the new offsets and sizes are:
+    // * Multiple subview offsets for a given dimension compose additively.
+    //   ("Offset by m" followed by "Offset by n" == "Offset by m + n")
+    // * Multiple sizes for a given dimension compose by taking the size of the
+    //   final subview and ignoring the rest. ("Take m values" followed by "Take
+    //   n values" == "Take n values") This size must also be the smallest one
+    //   by definition (a subview needs to be the same size as or smaller than
+    //   its source along each dimension; presumably subviews that are larger
+    //   than their sources are disallowed by validation).
+    for (auto it : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
+                             op.getMixedSizes())) {
+      auto opOffset = std::get<0>(it);
+      auto sourceOffset = std::get<1>(it);
+
+      // Already verified to be static above.
+      sizes.push_back(std::get<2>(it));
+
+      Attribute opOffsetAttr = opOffset.dyn_cast<Attribute>(),
+                sourceOffsetAttr = sourceOffset.dyn_cast<Attribute>();
+
+      if (opOffsetAttr && sourceOffsetAttr) {
+        // If both offsets are static we can simply calculate the combined
+        // offset statically.
+        offsets.push_back(rewriter.getI64IntegerAttr(
+            opOffsetAttr.cast<IntegerAttr>().getInt() +
+            sourceOffsetAttr.cast<IntegerAttr>().getInt()));
+      } else {
+        // When either offset is dynamic, we must emit an additional affine
+        // transformation to add the two offsets together dynamically. Static
+        // offsets are folded into the expression as constants; dynamic ones
+        // become symbols bound to the corresponding SSA values.
+        AffineExpr expr = rewriter.getAffineConstantExpr(0);
+        SmallVector<Value> affineApplyOperands;
+        for (auto valueOrAttr : {opOffset, sourceOffset}) {
+          if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
+            expr = expr + attr.cast<IntegerAttr>().getInt();
+          } else {
+            expr =
+                expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
+            affineApplyOperands.push_back(valueOrAttr.get<Value>());
+          }
+        }
+
+        AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
+        Value result = rewriter.create<AffineApplyOp>(op.getLoc(), map,
+                                                      affineApplyOperands);
+        offsets.push_back(result);
+      }
+    }
+
+    // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
+    // uses it can be removed by a (separate) dead code elimination pass.
+    rewriter.replaceOpWithNewOp<memref::SubViewOp>(op, sourceOp.source(),
+                                                   offsets, sizes, strides);
+    return success();
+  }
+};
+
+} // namespace
+
+// Adds the pattern that combines a subview of a subview into a single subview
+// to 'patterns'.
+void populateComposeSubViewPatterns(OwningRewritePatternList &patterns,
+                                    MLIRContext *context) {
+  patterns.insert<ComposeSubViewOpPattern>(context);
+}
+
+} // namespace mlir

diff  --git a/mlir/test/Transforms/compose-subview.mlir b/mlir/test/Transforms/compose-subview.mlir
new file mode 100644
index 0000000000000..9081ba232d175
--- /dev/null
+++ b/mlir/test/Transforms/compose-subview.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -test-compose-subview -split-input-file | FileCheck %s
+
+// CHECK: [[MAP:#.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 3456)
+#map0 = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 2304)>
+#map1 = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 3456)>
+
+// Two subviews with fully static offsets: the expected result is a single
+// subview of the function argument whose offsets are the per-dimension sums
+// ([2, 256] + [1, 128] = [3, 384]) and whose sizes are those of the last
+// subview.
+func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, #map1> {
+  //      CHECK: subview %arg0[3, 384] [1, 128] [1, 1]
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, [[MAP]]>
+  %0 = memref.subview %input[2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, #map0>
+  %1 = memref.subview %0[1, 128] [1, 128] [1, 1] : memref<2x256xf32, #map0> to memref<1x128xf32, #map1>
+  return %1 : memref<1x128xf32, #map1>
+}
+
+// -----
+
+// CHECK: [[MAP:#.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 3745)
+#map0 = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 1536)>
+#map1 = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 2688)>
+#map2 = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 3745)>
+
+// A chain of three subviews with static offsets: the expected result is one
+// subview of the function argument with offsets [3, 673] (the sums of
+// [1, 512], [1, 128] and [1, 33]) and the sizes of the last subview.
+func @main(%input: memref<4x1024xf32>) -> memref<1x10xf32, #map2> {
+  //      CHECK: subview %arg0[3, 673] [1, 10] [1, 1]
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x10xf32, [[MAP]]>
+  %0 = memref.subview %input[1, 512] [3, 256] [1, 1] : memref<4x1024xf32> to memref<3x256xf32, #map0>
+  %1 = memref.subview %0[1, 128] [2, 128] [1, 1] : memref<3x256xf32, #map0> to memref<2x128xf32, #map1>
+  %2 = memref.subview %1[1, 33] [1, 10] [1, 1] : memref<2x128xf32, #map1> to memref<1x10xf32, #map2>
+  return %2 : memref<1x10xf32, #map2>
+}
+
+// -----
+
+// CHECK: [[MAP:#.*]] = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)
+#map = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
+
+// Both subviews take their dimension-0 offset from an SSA value (constants 2
+// and 1) while dimension 1 is static: the expected output uses a single
+// constant 3 for dimension 0 and the static sum 384 for dimension 1.
+func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, #map> {
+  // CHECK: [[CST_3:%.*]] = constant 3 : index
+  %cst_1 = constant 1 : index
+  %cst_2 = constant 2 : index
+  //      CHECK: subview %arg0{{\[}}[[CST_3]], 384] [1, 128] [1, 1]
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, [[MAP]]>
+  %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, #map>
+  %1 = memref.subview %0[%cst_1, 128] [1, 128] [1, 1] : memref<2x256xf32, #map> to memref<1x128xf32, #map>
+  return %1 : memref<1x128xf32, #map>
+}
+
+// -----
+
+// CHECK: [[MAP:#.*]] = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)
+#map = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
+
+// Mixed static/dynamic offsets in each dimension: dimension 0 is an SSA value
+// in the first subview and static in the second, dimension 1 is the other way
+// round. The expected output uses constants 3 and 384 as the two offsets.
+func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, #map> {
+  // CHECK: [[CST_3:%.*]] = constant 3 : index
+  %cst_2 = constant 2 : index
+  // CHECK: [[CST_384:%.*]] = constant 384 : index
+  %cst_128 = constant 128 : index
+  //      CHECK: subview %arg0{{\[}}[[CST_3]], [[CST_384]]] [1, 128] [1, 1]
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, [[MAP]]>
+  %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, #map>
+  %1 = memref.subview %0[1, %cst_128] [1, 128] [1, 1] : memref<2x256xf32, #map> to memref<1x128xf32, #map>
+  return %1 : memref<1x128xf32, #map>
+}

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index acae8190f00ed..bd60cdfa78cce 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_library(MLIRTestTransforms
   TestDataLayoutQuery.cpp
   TestDominance.cpp
   TestDynamicPipeline.cpp
+  TestComposeSubView.cpp
   TestLoopFusion.cpp
   TestGpuMemoryPromotion.cpp
   TestGpuParallelLoopMapping.cpp

diff  --git a/mlir/test/lib/Transforms/TestComposeSubView.cpp b/mlir/test/lib/Transforms/TestComposeSubView.cpp
new file mode 100644
index 0000000000000..770e436c5addd
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestComposeSubView.cpp
@@ -0,0 +1,46 @@
+//===- TestComposeSubView.cpp - Test composed subviews --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the composed subview patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Transforms/ComposeSubView.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+// Function pass that greedily applies the composed-subview patterns to the
+// function it runs on.
+struct TestComposeSubViewPass
+    : public PassWrapper<TestComposeSubViewPass, FunctionPass> {
+  // The composed-subview pattern may create 'AffineApplyOp's, so the affine
+  // dialect has to be registered as a dependency.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect>();
+  }
+
+  // Collects the patterns and runs the greedy driver; the driver's result is
+  // deliberately ignored.
+  void runOnFunction() override {
+    MLIRContext *ctx = &getContext();
+    OwningRewritePatternList patterns(ctx);
+    populateComposeSubViewPatterns(patterns, ctx);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+// Registers 'TestComposeSubViewPass' under the "test-compose-subview" pass
+// argument.
+void registerTestComposeSubView() {
+  PassRegistration<TestComposeSubViewPass>(
+      "test-compose-subview", "Test combining composed subviews");
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 428b3d5063174..2bef89ea7dda7 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -72,6 +72,7 @@ void registerTestDialect(DialectRegistry &);
 void registerTestDominancePass();
 void registerTestDynamicPipelinePass();
 void registerTestExpandTanhPass();
+void registerTestComposeSubView();
 void registerTestGpuParallelLoopMappingPass();
 void registerTestIRVisitorsPass();
 void registerTestInterfaces();
@@ -148,6 +149,7 @@ void registerTestPasses() {
   test::registerTestDominancePass();
   test::registerTestDynamicPipelinePass();
   test::registerTestExpandTanhPass();
+  test::registerTestComposeSubView();
   test::registerTestGpuParallelLoopMappingPass();
   test::registerTestIRVisitorsPass();
   test::registerTestInterfaces();


        


More information about the Mlir-commits mailing list