[Mlir-commits] [mlir] 3ba1b1c - Add a pattern to combine composed subview ops
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 1 10:57:15 PDT 2021
Author: Aden Grue
Date: 2021-04-01T10:56:57-07:00
New Revision: 3ba1b1cd201dbf2d5f9da4d3a018091c3e3a2d78
URL: https://github.com/llvm/llvm-project/commit/3ba1b1cd201dbf2d5f9da4d3a018091c3e3a2d78
DIFF: https://github.com/llvm/llvm-project/commit/3ba1b1cd201dbf2d5f9da4d3a018091c3e3a2d78.diff
LOG: Add a pattern to combine composed subview ops
Differential Revision: https://reviews.llvm.org/D99229
Added:
mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h
mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp
mlir/test/Transforms/compose-subview.mlir
mlir/test/lib/Transforms/TestComposeSubView.cpp
Modified:
mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
mlir/test/lib/Transforms/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h
new file mode 100644
index 0000000000000..7a5ae3e8417b7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/ComposeSubView.h
@@ -0,0 +1,28 @@
+//===- ComposeSubView.h - Combining composed subview ops --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Patterns for combining composed subview ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_
+#define MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_
+
+namespace mlir {
+
+// Forward declarations, so that this header does not need to pull in the
+// pattern-rewriting headers.
+class RewritePatternSet;
+using OwningRewritePatternList = RewritePatternSet;
+class MLIRContext;
+
+/// Appends to `patterns` the rewrite that turns a `memref.subview` whose
+/// source is itself produced by a `memref.subview` into a single subview of
+/// the outer source. See ComposeSubView.cpp for the supported cases.
+void populateComposeSubViewPatterns(OwningRewritePatternList &patterns,
+                                    MLIRContext *context);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_COMPOSESUBVIEW_H_
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
index b01ff954d7f3c..2fa0b96bed7ab 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRStandardOpsTransforms
Bufferize.cpp
+ ComposeSubView.cpp
DecomposeCallGraphTypes.cpp
ExpandOps.cpp
FuncBufferize.cpp
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp
new file mode 100644
index 0000000000000..cabaa614fa2a9
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ComposeSubView.cpp
@@ -0,0 +1,136 @@
+//===- ComposeSubView.cpp - Combining composed subview ops ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns for combining composed subview ops (i.e. subview
+// of a subview becomes a single subview).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/Transforms/ComposeSubView.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+
+namespace {
+
+// Replaces a subview of a subview with a single subview. Only supports subview
+// ops with static sizes and static strides of 1 (both static and dynamic
+// offsets are supported). Neither subview may be rank-reducing.
+struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::SubViewOp op,
+                                PatternRewriter &rewriter) const override {
+    // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
+    // produces the input of the op we're rewriting (for 'SubViewOp' the input
+    // is called the "source" value). We can only combine them if both 'op' and
+    // 'sourceOp' are 'SubViewOp'.
+    auto sourceOp = op.source().getDefiningOp<memref::SubViewOp>();
+    if (!sourceOp)
+      return failure();
+
+    // A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the
+    // output memref that are statically known to be equal to 1. We do not
+    // allow 'sourceOp' to be a rank-reducing subview because then our two
+    // 'SubViewOp's would have different numbers of offset/size/stride
+    // parameters (just difficult to deal with, not impossible if we end up
+    // needing it). We also reject a rank-reducing 'op': the combined subview
+    // below is built with an inferred, non-rank-reduced result type, which
+    // would not match the type of the value it replaces.
+    if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank() ||
+        op.getSourceType().getRank() != op.getType().getRank())
+      return failure();
+
+    // Because we only support input strides of 1, the output stride is also
+    // always 1. The strides of *both* subviews have to be inspected: a
+    // non-unit stride on either one changes which elements are selected.
+    auto isStaticOne = [](OpFoldResult valueOrAttr) {
+      Attribute attr = valueOrAttr.dyn_cast<Attribute>();
+      return attr && attr.cast<IntegerAttr>().getInt() == 1;
+    };
+    if (!llvm::all_of(op.getMixedStrides(), isStaticOne) ||
+        !llvm::all_of(sourceOp.getMixedStrides(), isStaticOne))
+      return failure();
+
+    // We only support static sizes. This is checked before any IR is created
+    // below, because a pattern that reports failure() must leave the IR
+    // unchanged.
+    SmallVector<OpFoldResult> sizes = op.getMixedSizes();
+    if (llvm::any_of(sizes,
+                     [](OpFoldResult size) { return size.is<Value>(); }))
+      return failure();
+
+    // Offsets and strides for the combined 'SubViewOp' (the sizes are those
+    // of 'op', per the rules below).
+    SmallVector<OpFoldResult> offsets;
+    SmallVector<OpFoldResult> strides(sourceOp.getMixedStrides().size(),
+                                      rewriter.getI64IntegerAttr(1));
+
+    // The rules for calculating the new offsets and sizes are:
+    // * Multiple subview offsets for a given dimension compose additively.
+    //   ("Offset by m" followed by "Offset by n" == "Offset by m + n")
+    // * Multiple sizes for a given dimension compose by taking the size of the
+    //   final subview and ignoring the rest. ("Take m values" followed by "Take
+    //   n values" == "Take n values") This size must also be the smallest one
+    //   by definition (a subview needs to be the same size as or smaller than
+    //   its source along each dimension; presumably subviews that are larger
+    //   than their sources are disallowed by validation).
+    for (auto it :
+         llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets())) {
+      OpFoldResult opOffset = std::get<0>(it);
+      OpFoldResult sourceOffset = std::get<1>(it);
+
+      Attribute opOffsetAttr = opOffset.dyn_cast<Attribute>(),
+                sourceOffsetAttr = sourceOffset.dyn_cast<Attribute>();
+
+      if (opOffsetAttr && sourceOffsetAttr) {
+        // If both offsets are static we can simply calculate the combined
+        // offset statically.
+        offsets.push_back(rewriter.getI64IntegerAttr(
+            opOffsetAttr.cast<IntegerAttr>().getInt() +
+            sourceOffsetAttr.cast<IntegerAttr>().getInt()));
+        continue;
+      }
+
+      // When either offset is dynamic, we must emit an additional affine
+      // transformation to add the two offsets together dynamically. Static
+      // offsets become constant terms; dynamic ones become map symbols.
+      AffineExpr expr = rewriter.getAffineConstantExpr(0);
+      SmallVector<Value> affineApplyOperands;
+      for (auto valueOrAttr : {opOffset, sourceOffset}) {
+        if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
+          expr = expr + attr.cast<IntegerAttr>().getInt();
+        } else {
+          expr =
+              expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
+          affineApplyOperands.push_back(valueOrAttr.get<Value>());
+        }
+      }
+
+      AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
+      Value result = rewriter.create<AffineApplyOp>(op.getLoc(), map,
+                                                    affineApplyOperands);
+      offsets.push_back(result);
+    }
+
+    // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
+    // uses it can be removed by a (separate) dead code elimination pass.
+    rewriter.replaceOpWithNewOp<memref::SubViewOp>(op, sourceOp.source(),
+                                                   offsets, sizes, strides);
+    return success();
+  }
+};
+
+} // namespace
+
+/// Adds ComposeSubViewOpPattern to `patterns`.
+void populateComposeSubViewPatterns(OwningRewritePatternList &patterns,
+                                    MLIRContext *context) {
+  patterns.insert<ComposeSubViewOpPattern>(context);
+}
+
+} // namespace mlir
diff --git a/mlir/test/Transforms/compose-subview.mlir b/mlir/test/Transforms/compose-subview.mlir
new file mode 100644
index 0000000000000..9081ba232d175
--- /dev/null
+++ b/mlir/test/Transforms/compose-subview.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -test-compose-subview -split-input-file | FileCheck %s
+
+// Two static subviews: offsets add ([2, 256] + [1, 128] = [3, 384]) and the
+// sizes of the outer subview are kept.
+// CHECK: [[MAP:#.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 3456)
+#map0 = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 2304)>
+#map1 = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 3456)>
+
+func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, #map1> {
+  // CHECK: subview %arg0[3, 384] [1, 128] [1, 1]
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, [[MAP]]>
+  %0 = memref.subview %input[2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, #map0>
+  %1 = memref.subview %0[1, 128] [1, 128] [1, 1] : memref<2x256xf32, #map0> to memref<1x128xf32, #map1>
+  return %1 : memref<1x128xf32, #map1>
+}
+
+// -----
+
+// Three chained static subviews collapse into one:
+// [1, 512] + [1, 128] + [1, 33] = [3, 673].
+// CHECK: [[MAP:#.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 3745)
+#map0 = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 1536)>
+#map1 = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 2688)>
+#map2 = affine_map<(d0, d1) -> (d0 * 1024 + d1 + 3745)>
+
+func @main(%input: memref<4x1024xf32>) -> memref<1x10xf32, #map2> {
+  // CHECK: subview %arg0[3, 673] [1, 10] [1, 1]
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x10xf32, [[MAP]]>
+  %0 = memref.subview %input[1, 512] [3, 256] [1, 1] : memref<4x1024xf32> to memref<3x256xf32, #map0>
+  %1 = memref.subview %0[1, 128] [2, 128] [1, 1] : memref<3x256xf32, #map0> to memref<2x128xf32, #map1>
+  %2 = memref.subview %1[1, 33] [1, 10] [1, 1] : memref<2x128xf32, #map1> to memref<1x10xf32, #map2>
+  return %2 : memref<1x10xf32, #map2>
+}
+
+// -----
+
+// Dynamic offsets in the same dimension are added with an affine.apply; the
+// greedy driver folds it to a constant because both operands are constants.
+// CHECK: [[MAP:#.*]] = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)
+#map = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
+
+func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, #map> {
+  // CHECK: [[CST_3:%.*]] = constant 3 : index
+  %cst_1 = constant 1 : index
+  %cst_2 = constant 2 : index
+  // CHECK: subview %arg0{{\[}}[[CST_3]], 384] [1, 128] [1, 1]
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, [[MAP]]>
+  %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, #map>
+  %1 = memref.subview %0[%cst_1, 128] [1, 128] [1, 1] : memref<2x256xf32, #map> to memref<1x128xf32, #map>
+  return %1 : memref<1x128xf32, #map>
+}
+
+// -----
+
+// Mixed static and dynamic offsets: each dimension with a dynamic operand
+// goes through an affine.apply (folded here to a constant).
+// CHECK: [[MAP:#.*]] = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)
+#map = affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>
+
+func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, #map> {
+  // CHECK: [[CST_3:%.*]] = constant 3 : index
+  %cst_2 = constant 2 : index
+  // CHECK: [[CST_384:%.*]] = constant 384 : index
+  %cst_128 = constant 128 : index
+  // CHECK: subview %arg0{{\[}}[[CST_3]], [[CST_384]]] [1, 128] [1, 1]
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x128xf32, [[MAP]]>
+  %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, #map>
+  %1 = memref.subview %0[1, %cst_128] [1, 128] [1, 1] : memref<2x256xf32, #map> to memref<1x128xf32, #map>
+  return %1 : memref<1x128xf32, #map>
+}
+
+// -----
+
+// A non-unit stride on the inner subview is not supported, so the two
+// subviews must be left untouched.
+#map0 = affine_map<(d0, d1) -> (d0 * 1024 + d1 * 2 + 2304)>
+#map1 = affine_map<(d0, d1) -> (d0 * 1024 + d1 * 2 + 3456)>
+
+func @main(%input: memref<4x1024xf32>) -> memref<1x32xf32, #map1> {
+  // CHECK: %[[SV0:.*]] = memref.subview %arg0[2, 256] [2, 128] [1, 2]
+  // CHECK: memref.subview %[[SV0]][1, 64] [1, 32] [1, 1]
+  %0 = memref.subview %input[2, 256] [2, 128] [1, 2] : memref<4x1024xf32> to memref<2x128xf32, #map0>
+  %1 = memref.subview %0[1, 64] [1, 32] [1, 1] : memref<2x128xf32, #map0> to memref<1x32xf32, #map1>
+  return %1 : memref<1x32xf32, #map1>
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index acae8190f00ed..bd60cdfa78cce 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_library(MLIRTestTransforms
TestDataLayoutQuery.cpp
TestDominance.cpp
TestDynamicPipeline.cpp
+ TestComposeSubView.cpp
TestLoopFusion.cpp
TestGpuMemoryPromotion.cpp
TestGpuParallelLoopMapping.cpp
diff --git a/mlir/test/lib/Transforms/TestComposeSubView.cpp b/mlir/test/lib/Transforms/TestComposeSubView.cpp
new file mode 100644
index 0000000000000..770e436c5addd
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestComposeSubView.cpp
@@ -0,0 +1,46 @@
+//===- TestComposeSubView.cpp - Test composed subviews --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the composed subview patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Transforms/ComposeSubView.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that greedily applies the compose-subview patterns (plus
+/// folding) to each function.
+struct TestComposeSubViewPass
+    : public PassWrapper<TestComposeSubViewPass, FunctionPass> {
+  void runOnFunction() override;
+  void getDependentDialects(DialectRegistry &registry) const override;
+};
+
+void TestComposeSubViewPass::getDependentDialects(
+    DialectRegistry &registry) const {
+  // The pattern may create affine.apply ops to add dynamic offsets, so the
+  // affine dialect is declared as a dependency of this pass.
+  registry.insert<AffineDialect>();
+}
+
+void TestComposeSubViewPass::runOnFunction() {
+  OwningRewritePatternList patterns(&getContext());
+  populateComposeSubViewPatterns(patterns, &getContext());
+  // Convergence is not required for this test pass, so the result is ignored.
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Makes TestComposeSubViewPass available to mlir-opt under the
+/// `-test-compose-subview` flag.
+void registerTestComposeSubView() {
+  // Constructing the registration object is what performs the registration.
+  PassRegistration<TestComposeSubViewPass> registration(
+      "test-compose-subview", "Test combining composed subviews");
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 428b3d5063174..2bef89ea7dda7 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -72,6 +72,7 @@ void registerTestDialect(DialectRegistry &);
void registerTestDominancePass();
void registerTestDynamicPipelinePass();
void registerTestExpandTanhPass();
+void registerTestComposeSubView();
void registerTestGpuParallelLoopMappingPass();
void registerTestIRVisitorsPass();
void registerTestInterfaces();
@@ -148,6 +149,7 @@ void registerTestPasses() {
test::registerTestDominancePass();
test::registerTestDynamicPipelinePass();
test::registerTestExpandTanhPass();
+ test::registerTestComposeSubView();
test::registerTestGpuParallelLoopMappingPass();
test::registerTestIRVisitorsPass();
test::registerTestInterfaces();
More information about the Mlir-commits
mailing list