[Mlir-commits] [mlir] 6510fa9 - [mlir][memref] Add ValueBoundsOpInterface impls
Matthias Springer
llvmlistbot at llvm.org
Wed Apr 5 18:40:30 PDT 2023
Author: Matthias Springer
Date: 2023-04-06T10:35:52+09:00
New Revision: 6510fa90a0c12c18f39601c6f4f70bc7e916fe29
URL: https://github.com/llvm/llvm-project/commit/6510fa90a0c12c18f39601c6f4f70bc7e916fe29
DIFF: https://github.com/llvm/llvm-project/commit/6510fa90a0c12c18f39601c6f4f70bc7e916fe29.diff
LOG: [mlir][memref] Add ValueBoundsOpInterface impls
Differential Revision: https://reviews.llvm.org/D145695
Added:
mlir/include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h
mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index b37054b1c5e91..82f5ed96bfb96 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -100,6 +100,20 @@ class AllocLikeOp<string mnemonic,
static StringRef getAlignmentAttrStrName() { return "alignment"; }
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+
+    SmallVector<OpFoldResult> getMixedSizes() {
+      // One entry per dimension of the result type. Static extents are
+      // materialized as index attributes; dynamic extents are taken, in
+      // order, from the op's dynamic size operands.
+      MemRefType memrefType = getType();
+      auto dynamicSizes = getDynamicSizes();
+      OpBuilder builder(getContext());
+      SmallVector<OpFoldResult> sizes;
+      unsigned nextDynamic = 0;
+      for (int64_t dim = 0, rank = memrefType.getRank(); dim < rank; ++dim) {
+        if (!memrefType.isDynamicDim(dim)) {
+          sizes.push_back(builder.getIndexAttr(memrefType.getShape()[dim]));
+          continue;
+        }
+        sizes.push_back(dynamicSizes[nextDynamic++]);
+      }
+      return sizes;
+    }
}];
let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h
new file mode 100644
index 0000000000000..eec43b7609c0d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+/// Registers, through an extension on the memref dialect, the
+/// ValueBoundsOpInterface external models for memref.alloc, memref.alloca,
+/// memref.cast, memref.dim, memref.get_global, memref.rank and
+/// memref.subview.
+void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_IR_VALUEBOUNDSOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 1a4e0f948268f..f947655ce48ba 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -46,6 +46,7 @@
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
@@ -139,6 +140,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
linalg::registerTilingInterfaceExternalModels(registry);
memref::registerBufferizableOpInterfaceExternalModels(registry);
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+ memref::registerValueBoundsOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
shape::registerBufferizableOpInterfaceExternalModels(registry);
sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index f9228380c4f25..3aedd3783fa8f 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRMemRefDialect
MemRefDialect.cpp
MemRefOps.cpp
+ ValueBoundsOpInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
@@ -21,5 +22,6 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRIR
MLIRShapedOpInterfaces
MLIRSideEffectInterfaces
+ MLIRValueBoundsOpInterface
MLIRViewLikeInterface
)
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
new file mode 100644
index 0000000000000..ca63fb3d0de6a
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -0,0 +1,129 @@
+//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace memref {
+namespace {
+
+/// Value bounds for alloc-like ops (used for memref.alloc and memref.alloca):
+/// dimension `dim` of the allocated memref equals the corresponding entry of
+/// the op's mixed (static attribute or dynamic operand) sizes.
+template <typename OpTy>
+struct AllocOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<AllocOpInterface<OpTy>,
+                                                   OpTy> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    OpTy allocLikeOp = cast<OpTy>(op);
+    assert(value == allocLikeOp.getResult() && "invalid value");
+
+    auto sizes = allocLikeOp.getMixedSizes();
+    cstr.bound(value)[dim] == sizes[dim];
+  }
+};
+
+/// Value bounds for memref.cast: when both the source and the result are of
+/// type MemRefType, each result dimension equals the same source dimension.
+/// Otherwise no constraint is added.
+struct CastOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto castOp = cast<CastOp>(op);
+    assert(value == castOp.getResult() && "invalid value");
+
+    Value source = castOp.getSource();
+    bool bothMemRefs = castOp.getResult().getType().isa<MemRefType>() &&
+                       source.getType().isa<MemRefType>();
+    if (!bothMemRefs)
+      return;
+    cstr.bound(value)[dim] == cstr.getExpr(source, dim);
+  }
+};
+
+/// Value bounds for memref.dim: when the queried index is a constant, the
+/// result equals that dimension of the source memref. A non-constant index
+/// contributes no constraint.
+struct DimOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto dimOp = cast<DimOp>(op);
+    assert(value == dimOp.getResult() && "invalid value");
+
+    if (auto constIndex = dimOp.getConstantIndex())
+      cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
+  }
+};
+
+/// Value bounds for memref.get_global: the queried dimension of the result is
+/// asserted to be static and is bounded by its constant size.
+struct GetGlobalOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
+                                                   GetGlobalOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto getGlobalOp = cast<GetGlobalOp>(op);
+    assert(value == getGlobalOp.getResult() && "invalid value");
+
+    auto memrefType = getGlobalOp.getType();
+    assert(!memrefType.isDynamicDim(dim) && "expected static dim");
+    int64_t staticSize = memrefType.getDimSize(dim);
+    cstr.bound(value)[dim] == staticSize;
+  }
+};
+
+/// Value bounds for memref.rank: if the operand's type is a MemRefType, the
+/// result equals its static rank; otherwise no constraint is added.
+struct RankOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto rankOp = cast<RankOp>(op);
+    assert(value == rankOp.getResult() && "invalid value");
+
+    if (auto memrefType = rankOp.getMemref().getType().dyn_cast<MemRefType>())
+      cstr.bound(value) == memrefType.getRank();
+  }
+};
+
+/// Value bounds for memref.subview: result dimension `dim` equals the size of
+/// the `dim`-th subview dimension that is not dropped by rank reduction.
+struct SubViewOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
+                                                   SubViewOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto subViewOp = cast<SubViewOp>(op);
+    assert(value == subViewOp.getResult() && "invalid value");
+
+    auto sizes = subViewOp.getMixedSizes();
+    llvm::SmallBitVector dropped = subViewOp.getDroppedDims();
+    int64_t resultDim = 0;
+    for (int64_t srcDim = 0, e = sizes.size(); srcDim < e; ++srcDim) {
+      // Rank-reduced dimensions have no counterpart in the result type.
+      if (dropped.test(srcDim))
+        continue;
+      if (resultDim == dim) {
+        cstr.bound(value)[dim] == sizes[srcDim];
+        return;
+      }
+      ++resultDim;
+    }
+    llvm_unreachable("could not find non-rank-reduced dim");
+  }
+};
+
+} // namespace
+} // namespace memref
+} // namespace mlir
+
+/// Adds a DialectRegistry extension on the memref dialect. When invoked, the
+/// extension attaches the ValueBoundsOpInterface external models defined in
+/// this file to memref.alloc, memref.alloca, memref.cast, memref.dim,
+/// memref.get_global, memref.rank and memref.subview.
+void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    memref::AllocOp::attachInterface<memref::AllocOpInterface<memref::AllocOp>>(
+        *ctx);
+    memref::AllocaOp::attachInterface<
+        memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
+    memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
+    memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+    memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
+    memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
+    memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
new file mode 100644
index 0000000000000..0e0f216b05d48
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
+// RUN: -split-input-file | FileCheck %s
+
+// Reifies both dimensions of a memref.alloc result: the static extent (6) is
+// expected as an index constant, the dynamic extent as the %sz operand.
+// CHECK-LABEL: func @memref_alloc(
+// CHECK-SAME: %[[sz:.*]]: index
+// CHECK: %[[c6:.*]] = arith.constant 6 : index
+// CHECK: return %[[c6]], %[[sz]]
+func.func @memref_alloc(%sz: index) -> (index, index) {
+ %0 = memref.alloc(%sz) : memref<6x?xf32>
+ %1 = "test.reify_bound"(%0) {dim = 0} : (memref<6x?xf32>) -> (index)
+ %2 = "test.reify_bound"(%0) {dim = 1} : (memref<6x?xf32>) -> (index)
+ return %1, %2 : index, index
+}
+
+// -----
+
+// Same as the alloc case, but for memref.alloca: the static extent (6) is
+// expected as an index constant, the dynamic extent as the %sz operand.
+// CHECK-LABEL: func @memref_alloca(
+// CHECK-SAME: %[[sz:.*]]: index
+// CHECK: %[[c6:.*]] = arith.constant 6 : index
+// CHECK: return %[[c6]], %[[sz]]
+func.func @memref_alloca(%sz: index) -> (index, index) {
+ %0 = memref.alloca(%sz) : memref<6x?xf32>
+ %1 = "test.reify_bound"(%0) {dim = 0} : (memref<6x?xf32>) -> (index)
+ %2 = "test.reify_bound"(%0) {dim = 1} : (memref<6x?xf32>) -> (index)
+ return %1, %2 : index, index
+}
+
+// -----
+
+// The cast result has a dynamic dimension. Its bound is expected to be
+// derived from the static source type (10).
+// CHECK-LABEL: func @memref_cast(
+// CHECK: %[[c10:.*]] = arith.constant 10 : index
+// CHECK: return %[[c10]]
+func.func @memref_cast(%m: memref<10xf32>) -> index {
+ %0 = memref.cast %m : memref<10xf32> to memref<?xf32>
+ %1 = "test.reify_bound"(%0) {dim = 0} : (memref<?xf32>) -> (index)
+ return %1 : index
+}
+
+// -----
+
+// A memref.dim with constant index 0 is expected to reify to a memref.dim on
+// the same source memref.
+// NOTE(review): the `dim` pattern variable is bound twice; the `return` line
+// is matched against the second binding.
+// CHECK-LABEL: func @memref_dim(
+// CHECK-SAME: %[[m:.*]]: memref<?xf32>
+// CHECK: %[[dim:.*]] = memref.dim %[[m]]
+// CHECK: %[[dim:.*]] = memref.dim %[[m]]
+// CHECK: return %[[dim]]
+func.func @memref_dim(%m: memref<?xf32>) -> index {
+ %c0 = arith.constant 0 : index
+ %0 = memref.dim %m, %c0 : memref<?xf32>
+ %1 = "test.reify_bound"(%0) : (index) -> (index)
+ return %1 : index
+}
+
+// -----
+
+// The extent of a memref.get_global result is expected to come from its
+// static type (4).
+// CHECK-LABEL: func @memref_get_global(
+// CHECK: %[[c4:.*]] = arith.constant 4 : index
+// CHECK: return %[[c4]]
+memref.global "private" @gv0 : memref<4xf32> = dense<[0.0, 1.0, 2.0, 3.0]>
+func.func @memref_get_global() -> index {
+ %0 = memref.get_global @gv0 : memref<4xf32>
+ %1 = "test.reify_bound"(%0) {dim = 0} : (memref<4xf32>) -> (index)
+ return %1 : index
+}
+
+// -----
+
+// memref.rank of a 1-D memref is expected to reify to the constant 1.
+// NOTE(review): the captured `t` variable is not used by later patterns.
+// CHECK-LABEL: func @memref_rank(
+// CHECK-SAME: %[[t:.*]]: memref<5xf32>
+// CHECK: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: return %[[c1]]
+func.func @memref_rank(%m: memref<5xf32>) -> index {
+ %0 = memref.rank %m : memref<5xf32>
+ %1 = "test.reify_bound"(%0) : (index) -> (index)
+ return %1 : index
+}
+
+// -----
+
+// The single result dimension of a 1-D subview is expected to reify to the
+// subview's dynamic size operand (%sz).
+// CHECK-LABEL: func @memref_subview(
+// CHECK-SAME: %[[m:.*]]: memref<?xf32>, %[[sz:.*]]: index
+// CHECK: return %[[sz]]
+func.func @memref_subview(%m: memref<?xf32>, %sz: index) -> index {
+ %0 = memref.subview %m[2][%sz][1] : memref<?xf32> to memref<?xf32, strided<[1], offset: 2>>
+ %1 = "test.reify_bound"(%0) {dim = 0} : (memref<?xf32, strided<[1], offset: 2>>) -> (index)
+ return %1 : index
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 2ee2ecd318846..19e98a9453bae 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10257,6 +10257,7 @@ cc_library(
),
hdrs = [
"include/mlir/Dialect/MemRef/IR/MemRef.h",
+ "include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h",
"include/mlir/Dialect/MemRef/Utils/MemRefUtils.h",
],
includes = ["include"],
@@ -10271,6 +10272,7 @@ cc_library(
":MemRefBaseIncGen",
":MemRefOpsIncGen",
":ShapedOpInterfaces",
+ ":ValueBoundsOpInterface",
":ViewLikeInterface",
"//llvm:Support",
"//llvm:TargetParser",
More information about the Mlir-commits
mailing list