[Mlir-commits] [mlir] 6510fa9 - [mlir][memref] Add ValueBoundsOpInterface impls

Matthias Springer llvmlistbot at llvm.org
Wed Apr 5 18:40:30 PDT 2023


Author: Matthias Springer
Date: 2023-04-06T10:35:52+09:00
New Revision: 6510fa90a0c12c18f39601c6f4f70bc7e916fe29

URL: https://github.com/llvm/llvm-project/commit/6510fa90a0c12c18f39601c6f4f70bc7e916fe29
DIFF: https://github.com/llvm/llvm-project/commit/6510fa90a0c12c18f39601c6f4f70bc7e916fe29.diff

LOG: [mlir][memref] Add ValueBoundsOpInterface impls

Differential Revision: https://reviews.llvm.org/D145695

Added: 
    mlir/include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h
    mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
    mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index b37054b1c5e91..82f5ed96bfb96 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -100,6 +100,25 @@ class AllocLikeOp<string mnemonic,
     static StringRef getAlignmentAttrStrName() { return "alignment"; }
 
     MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+
+    /// Return the size of each dimension of the result type: an index
+    /// attribute for static dimensions, otherwise the matching operand of
+    /// `getDynamicSizes()` (dynamic sizes are consumed in order).
+    SmallVector<OpFoldResult> getMixedSizes() {
+      MemRefType type = getType();
+      auto dynamicSizes = getDynamicSizes();
+      Builder b(getContext());
+      SmallVector<OpFoldResult> result;
+      result.reserve(type.getRank());
+      unsigned dynIdx = 0;
+      for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
+        if (type.isDynamicDim(i))
+          result.push_back(dynamicSizes[dynIdx++]);
+        else
+          result.push_back(b.getIndexAttr(type.getDimSize(i)));
+      }
+      return result;
+    }
   }];
 
   let assemblyFormat = [{

diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h
new file mode 100644
index 0000000000000..eec43b7609c0d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h
@@ -0,0 +1,22 @@
+//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+/// Register the external models that implement ValueBoundsOpInterface for
+/// memref dialect ops in the given registry.
+void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_IR_VALUEBOUNDSOPINTERFACEIMPL_H

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 1a4e0f948268f..f947655ce48ba 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -46,6 +46,7 @@
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
 #include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
@@ -139,6 +140,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   linalg::registerTilingInterfaceExternalModels(registry);
   memref::registerBufferizableOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+  memref::registerValueBoundsOpInterfaceExternalModels(registry);
   scf::registerBufferizableOpInterfaceExternalModels(registry);
   shape::registerBufferizableOpInterfaceExternalModels(registry);
   sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);

diff  --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index f9228380c4f25..3aedd3783fa8f 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRMemRefDialect
   MemRefDialect.cpp
   MemRefOps.cpp
+  ValueBoundsOpInterfaceImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
@@ -21,5 +22,6 @@ add_mlir_dialect_library(MLIRMemRefDialect
   MLIRIR
   MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces
+  MLIRValueBoundsOpInterface
   MLIRViewLikeInterface
 )

diff  --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
new file mode 100644
index 0000000000000..ca63fb3d0de6a
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -0,0 +1,153 @@
+//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace memref {
+namespace {
+
+/// Bounds for memref.alloc / memref.alloca results: every result dimension is
+/// equal to the corresponding mixed (static or dynamic) size of the op.
+template <typename OpTy>
+struct AllocOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<AllocOpInterface<OpTy>,
+                                                   OpTy> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto allocOp = cast<OpTy>(op);
+    assert(value == allocOp.getResult() && "invalid value");
+
+    cstr.bound(value)[dim] == allocOp.getMixedSizes()[dim];
+  }
+};
+
+/// Bounds for memref.cast results: a result dimension is equal to the same
+/// dimension of the source. Only applied when both types are ranked memrefs.
+struct CastOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto castOp = cast<CastOp>(op);
+    assert(value == castOp.getResult() && "invalid value");
+
+    if (castOp.getResult().getType().isa<MemRefType>() &&
+        castOp.getSource().getType().isa<MemRefType>()) {
+      cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
+    }
+  }
+};
+
+/// Bounds for memref.dim results: the result is equal to the queried dimension
+/// of the source, provided that the dimension index is a constant.
+struct DimOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto dimOp = cast<DimOp>(op);
+    assert(value == dimOp.getResult() && "invalid value");
+
+    auto constIndex = dimOp.getConstantIndex();
+    if (!constIndex.has_value())
+      return;
+    // A constant index may still not name a valid source dimension (e.g., a
+    // negative index, or an index into an unranked memref that cannot be
+    // range-checked). Conservatively add no bound in that case.
+    auto memrefType = dimOp.getSource().getType().dyn_cast<MemRefType>();
+    if (!memrefType || *constIndex < 0 || *constIndex >= memrefType.getRank())
+      return;
+    cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
+  }
+};
+
+/// Bounds for memref.get_global results: the result type is expected to be
+/// statically shaped, so every dimension is bounded by a constant.
+struct GetGlobalOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
+                                                   GetGlobalOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto getGlobalOp = cast<GetGlobalOp>(op);
+    assert(value == getGlobalOp.getResult() && "invalid value");
+
+    auto type = getGlobalOp.getType();
+    assert(!type.isDynamicDim(dim) && "expected static dim");
+    cstr.bound(value)[dim] == type.getDimSize(dim);
+  }
+};
+
+/// Bounds for memref.rank results: the result is equal to the rank of the
+/// operand if the operand is a ranked memref.
+struct RankOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto rankOp = cast<RankOp>(op);
+    assert(value == rankOp.getResult() && "invalid value");
+
+    auto memrefType = rankOp.getMemref().getType().dyn_cast<MemRefType>();
+    if (!memrefType)
+      return;
+    cstr.bound(value) == memrefType.getRank();
+  }
+};
+
+/// Bounds for memref.subview results: a result dimension is equal to the size
+/// of the corresponding non-rank-reduced source dimension.
+struct SubViewOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
+                                                   SubViewOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto subViewOp = cast<SubViewOp>(op);
+    assert(value == subViewOp.getResult() && "invalid value");
+
+    // Materialize the mixed sizes once rather than on every loop iteration.
+    SmallVector<OpFoldResult> sizes = subViewOp.getMixedSizes();
+    llvm::SmallBitVector dropped = subViewOp.getDroppedDims();
+    // Result dimension that the next non-dropped source dimension maps to.
+    int64_t resultDim = 0;
+    for (int64_t i = 0, e = sizes.size(); i < e; ++i) {
+      // Rank-reduced dimensions have no counterpart in the result.
+      if (dropped.test(i))
+        continue;
+      if (resultDim == dim) {
+        cstr.bound(value)[dim] == sizes[i];
+        return;
+      }
+      ++resultDim;
+    }
+    llvm_unreachable("could not find non-rank-reduced dim");
+  }
+};
+
+} // namespace
+} // namespace memref
+} // namespace mlir
+
+/// Add a registry extension that attaches the ValueBoundsOpInterface models
+/// defined above to the corresponding memref ops.
+void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    memref::AllocOp::attachInterface<memref::AllocOpInterface<memref::AllocOp>>(
+        *ctx);
+    memref::AllocaOp::attachInterface<
+        memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
+    memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
+    memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+    memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
+    memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
+    memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
+  });
+}

diff  --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
new file mode 100644
index 0000000000000..0e0f216b05d48
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
+// RUN:     -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @memref_alloc(
+//  CHECK-SAME:     %[[sz:.*]]: index
+//       CHECK:   %[[c6:.*]] = arith.constant 6 : index
+//       CHECK:   return %[[c6]], %[[sz]]
+func.func @memref_alloc(%sz: index) -> (index, index) {
+  %0 = memref.alloc(%sz) : memref<6x?xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<6x?xf32>) -> (index)
+  %2 = "test.reify_bound"(%0) {dim = 1} : (memref<6x?xf32>) -> (index)
+  return %1, %2 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_alloca(
+//  CHECK-SAME:     %[[sz:.*]]: index
+//       CHECK:   %[[c6:.*]] = arith.constant 6 : index
+//       CHECK:   return %[[c6]], %[[sz]]
+func.func @memref_alloca(%sz: index) -> (index, index) {
+  %0 = memref.alloca(%sz) : memref<6x?xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<6x?xf32>) -> (index)
+  %2 = "test.reify_bound"(%0) {dim = 1} : (memref<6x?xf32>) -> (index)
+  return %1, %2 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_cast(
+//       CHECK:   %[[c10:.*]] = arith.constant 10 : index
+//       CHECK:   return %[[c10]]
+func.func @memref_cast(%m: memref<10xf32>) -> index {
+  %0 = memref.cast %m : memref<10xf32> to memref<?xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<?xf32>) -> (index)
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_dim(
+//  CHECK-SAME:     %[[m:.*]]: memref<?xf32>
+//       CHECK:   memref.dim %[[m]]
+//       CHECK:   %[[dim:.*]] = memref.dim %[[m]]
+//       CHECK:   return %[[dim]]
+func.func @memref_dim(%m: memref<?xf32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = memref.dim %m, %c0 : memref<?xf32>
+  %1 = "test.reify_bound"(%0) : (index) -> (index)
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_get_global(
+//       CHECK:   %[[c4:.*]] = arith.constant 4 : index
+//       CHECK:   return %[[c4]]
+memref.global "private" @gv0 : memref<4xf32> = dense<[0.0, 1.0, 2.0, 3.0]>
+func.func @memref_get_global() -> index {
+  %0 = memref.get_global @gv0 : memref<4xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<4xf32>) -> (index)
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_rank(
+//  CHECK-SAME:     %[[m:.*]]: memref<5xf32>
+//       CHECK:   %[[c1:.*]] = arith.constant 1 : index
+//       CHECK:   return %[[c1]]
+func.func @memref_rank(%m: memref<5xf32>) -> index {
+  %0 = memref.rank %m : memref<5xf32>
+  %1 = "test.reify_bound"(%0) : (index) -> (index)
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_subview(
+//  CHECK-SAME:     %[[m:.*]]: memref<?xf32>, %[[sz:.*]]: index
+//       CHECK:   return %[[sz]]
+func.func @memref_subview(%m: memref<?xf32>, %sz: index) -> index {
+  %0 = memref.subview %m[2][%sz][1] : memref<?xf32> to memref<?xf32, strided<[1], offset: 2>>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<?xf32, strided<[1], offset: 2>>) -> (index)
+  return %1 : index
+}

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 2ee2ecd318846..19e98a9453bae 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10257,6 +10257,7 @@ cc_library(
     ),
     hdrs = [
         "include/mlir/Dialect/MemRef/IR/MemRef.h",
+        "include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h",
         "include/mlir/Dialect/MemRef/Utils/MemRefUtils.h",
     ],
     includes = ["include"],
@@ -10271,6 +10272,7 @@ cc_library(
         ":MemRefBaseIncGen",
         ":MemRefOpsIncGen",
         ":ShapedOpInterfaces",
+        ":ValueBoundsOpInterface",
         ":ViewLikeInterface",
         "//llvm:Support",
         "//llvm:TargetParser",


        


More information about the Mlir-commits mailing list