[Mlir-commits] [mlir] 289cfe9 - [mlir][linalg] ValueBoundsOpInterface: Add support for linalg.index

Matthias Springer llvmlistbot at llvm.org
Tue Apr 18 23:58:41 PDT 2023


Author: Matthias Springer
Date: 2023-04-19T15:51:37+09:00
New Revision: 289cfe9ccdcb04604580ae866533d7b17654ab93

URL: https://github.com/llvm/llvm-project/commit/289cfe9ccdcb04604580ae866533d7b17654ab93
DIFF: https://github.com/llvm/llvm-project/commit/289cfe9ccdcb04604580ae866533d7b17654ab93.diff

LOG: [mlir][linalg] ValueBoundsOpInterface: Add support for linalg.index

Differential Revision: https://reviews.llvm.org/D148598

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
index 389cac41bafb2..55d09c421e31b 100644
--- a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -17,6 +17,50 @@ namespace mlir {
 namespace linalg {
 namespace {
 
+/// External model of `ValueBoundsOpInterface` for `linalg.index`.
+///
+/// `linalg.index <dim>` yields the current value of iteration dimension `dim`
+/// of the enclosing LinalgOp. This model bounds it by 0 <= index < N, where N
+/// is the size of the operand dimension that defines iteration dimension
+/// `dim`.
+struct IndexOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<IndexOpInterface, IndexOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto indexOp = cast<IndexOp>(op);
+    auto linalgOp = indexOp->getParentOfType<LinalgOp>();
+    assert(linalgOp && "expected linalg.index to be nested in a LinalgOp");
+    assert(value == indexOp.getResult() && "invalid value");
+
+    // index >= 0
+    cstr.bound(value) >= 0;
+
+    // index < dim size
+    // The shapes-to-loops map goes from the flattened list of all operand
+    // dimensions to the iteration dimensions; its `dim`-th result names the
+    // flattened operand dimension that defines iteration dimension `dim`.
+    int64_t flatDimPos = linalgOp.getShapesToLoopsMap()
+                             .getResult(indexOp.getDim())
+                             .cast<AffineDimExpr>()
+                             .getPosition();
+    // Find the `flatDimPos`-th operand dimension. Count dimensions with the
+    // interface's operand rank, which is 0 for non-shaped (scalar) operands
+    // such as a scalar `ins` operand of linalg.generic. Casting every operand
+    // to ShapedType would assert on such operands.
+    int64_t flatDimCtr = 0;
+    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+      assert(flatDimPos >= flatDimCtr && "invalid pos");
+      int64_t rank = linalgOp.getRank(&opOperand);
+      if (flatDimPos < flatDimCtr + rank) {
+        cstr.bound(value) <
+            cstr.getExpr(opOperand.get(), flatDimPos - flatDimCtr);
+        break;
+      }
+      flatDimCtr += rank;
+    }
+  }
+};
+
 /// Helper structure that iterates over all LinalgOps in `OpTys` and registers
 /// the `ValueBoundsOpInterface` with each of them.
 template <typename... Ops> struct LinalgValueBoundsOpInterfaceHelper {
@@ -34,6 +64,8 @@ template <typename... Ops> struct LinalgValueBoundsOpInterfaceHelper {
 void mlir::linalg::registerValueBoundsOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
+    IndexOp::attachInterface<IndexOpInterface>(*ctx);
+
     // Register all Linalg structured ops.
     LinalgValueBoundsOpInterfaceHelper<
 #define GET_OP_LIST

diff  --git a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir
index 537bc98ee544c..189c8e649ba5e 100644
--- a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir
@@ -11,3 +11,54 @@ func.func @linalg_fill(%t: tensor<?xf32>, %f: f32) -> index {
   %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<?xf32>) -> (index)
   return %1 : index
 }
+
+// -----
+
+#accesses = [  // one indexing map per operand, in operand order
+  affine_map<(i, j, k) -> (j, i)>,         // %arg0
+  affine_map<(i, j, k) -> (i, k, i + j)>   // %arg1
+]
+
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel", "parallel"]
+}
+
+// CHECK-LABEL: func @linalg_index(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?x?xf32>
+func.func @linalg_index(%arg0: memref<?x?xf32>,
+                        %arg1: memref<?x5x?xf32>) {
+  linalg.generic #trait
+                 ins(%arg0 : memref<?x?xf32>)
+                 outs(%arg1 : memref<?x5x?xf32>)
+  {
+    ^bb(%a: f32, %b: f32):
+      // CHECK: %[[c1:.*]] = arith.constant 1 : index
+      // CHECK: %[[ub_0:.*]] = memref.dim %[[arg0]], %[[c1]]
+      // CHECK: "test.some_use"(%[[ub_0]])
+      %0 = linalg.index 0 : index  // loop i: UB expected to be the size of %arg0 dim 1
+      %ub_0 = "test.reify_bound"(%0) {type = "UB"} : (index) -> (index)
+      "test.some_use"(%ub_0) : (index) -> ()
+
+      // CHECK: %[[c0:.*]] = arith.constant 0 : index
+      // CHECK: "test.some_use"(%[[c0]])
+      %lb_0 = "test.reify_bound"(%0) {type = "LB"} : (index) -> (index)  // LB of any index is 0
+      "test.some_use"(%lb_0) : (index) -> ()
+
+      // CHECK: %[[c0:.*]] = arith.constant 0 : index
+      // CHECK: %[[ub_1:.*]] = memref.dim %[[arg0]], %[[c0]]
+      // CHECK: "test.some_use"(%[[ub_1]])
+      %1 = linalg.index 1 : index  // loop j: UB expected to be the size of %arg0 dim 0
+      %ub_1 = "test.reify_bound"(%1) {type = "UB"} : (index) -> (index)
+      "test.some_use"(%ub_1) : (index) -> ()
+
+      // CHECK: %[[ub_2:.*]] = arith.constant 5 : index
+      // CHECK: "test.some_use"(%[[ub_2]])
+      %2 = linalg.index 2 : index  // loop k: UB expected to be the static size 5 of %arg1 dim 1
+      %ub_2 = "test.reify_bound"(%2) {type = "UB"} : (index) -> (index)
+      "test.some_use"(%ub_2) : (index) -> ()
+
+      linalg.yield %b : f32
+  }
+  return
+}


        


More information about the Mlir-commits mailing list