[Mlir-commits] [llvm] [mlir] [mlir][Interfaces][NFC] Add TableGen test op for value bounds tests (PR #88717)

Matthias Springer llvmlistbot at llvm.org
Mon Apr 15 06:13:38 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/88717

>From bc3b2989bea3ad4e15c9891e9d35618bf491b54c Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 15 Apr 2024 11:50:20 +0000
Subject: [PATCH] [mlir][Interfaces][NFC] Add TableGen test op for value bounds
 tests

This commit is a code cleanup. It defines the test ops that are used for the `ValueBoundsOpInterface` tests in TableGen, along with proper verifiers.
---
 .../value-bounds-op-interface-impl.mlir       |   2 +-
 .../Affine/value-bounds-reification.mlir      |   6 +-
 .../value-bounds-op-interface-impl.mlir       |   4 +-
 .../Dialect/Vector/test-scalable-bounds.mlir  |  18 +-
 mlir/test/lib/Dialect/Affine/CMakeLists.txt   |   8 +
 .../Dialect/Affine/TestReifyValueBounds.cpp   | 305 ++++++------------
 mlir/test/lib/Dialect/Test/CMakeLists.txt     |   1 +
 mlir/test/lib/Dialect/Test/TestDialect.cpp    |  48 +++
 mlir/test/lib/Dialect/Test/TestDialect.h      |   1 +
 mlir/test/lib/Dialect/Test/TestOps.td         |  47 +++
 .../mlir/test/BUILD.bazel                     |   2 +
 11 files changed, 226 insertions(+), 216 deletions(-)

diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 10da91870f49d9..23c6872dcebe94 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -74,7 +74,7 @@ func.func @composed_affine_apply(%i1 : index) -> (index) {
   %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
   %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
   %s = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%i2, %i3]
-  %reified = "test.reify_constant_bound"(%s) {type = "EQ"} : (index) -> (index)
+  %reified = "test.reify_bound"(%s) {type = "EQ", constant} : (index) -> (index)
   return %reified : index
 }
 
diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
index 909c9098c51607..75622f59af83be 100644
--- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
@@ -47,7 +47,7 @@ func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f
     %bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
     "test.some_use"(%bound) : (index) -> ()
 
-    %bound_const = "test.reify_constant_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
+    %bound_const = "test.reify_bound"(%filled) {dim = 1, type = "UB", constant} : (tensor<1x?xi32>) -> (index)
     "test.some_use"(%bound_const) : (index) -> ()
   }
   return
@@ -93,7 +93,7 @@ func.func @reify_slice_bound2(%lb0: index, %ub0: index, %step0: index,
 
     // CHECK: %[[c129:.*]] = arith.constant 129 : index
     // CHECK: "test.some_use"(%[[c129]])
-    %lb1_ub_const = "test.reify_constant_bound"(%lb1) {type = "UB"} : (index) -> (index)
+    %lb1_ub_const = "test.reify_bound"(%lb1) {type = "UB", constant} : (index) -> (index)
     "test.some_use"(%lb1_ub_const) : (index) -> ()
 
     scf.for %iv1 = %lb1 to %ub1 step %c32 {
@@ -116,7 +116,7 @@ func.func @reify_slice_bound2(%lb0: index, %ub0: index, %step0: index,
 
         // CHECK: %[[c32:.*]] = arith.constant 32 : index
         // CHECK: "test.some_use"(%[[c32]])
-        %matmul_ub_const = "test.reify_constant_bound"(%matmul) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
+        %matmul_ub_const = "test.reify_bound"(%matmul) {dim = 1, type = "UB", constant} : (tensor<1x?xi32>) -> (index)
         "test.some_use"(%matmul_ub_const) : (index) -> ()
       }
     }
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index 0c90bcdb42028c..0ba9983723a0a1 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -83,7 +83,7 @@ func.func @extract_slice_static(%t: tensor<?xf32>) -> index {
 func.func @extract_slice_dynamic_constant(%t: tensor<?xf32>, %sz: index) -> index {
   %0 = tensor.extract_slice %t[2][%sz][1] : tensor<?xf32> to tensor<?xf32>
   // expected-error @below{{could not reify bound}}
-  %1 = "test.reify_constant_bound"(%0) {dim = 0} : (tensor<?xf32>) -> (index)
+  %1 = "test.reify_bound"(%0) {dim = 0, constant} : (tensor<?xf32>) -> (index)
   return %1 : index
 }
 
@@ -95,7 +95,7 @@ func.func @extract_slice_dynamic_constant(%t: tensor<?xf32>, %sz: index) -> inde
 //       CHECK:   return %[[c5]]
 func.func @extract_slice_static_constant(%t: tensor<?xf32>) -> index {
   %0 = tensor.extract_slice %t[2][5][1] : tensor<?xf32> to tensor<5xf32>
-  %1 = "test.reify_constant_bound"(%0) {dim = 0} : (tensor<5xf32>) -> (index)
+  %1 = "test.reify_bound"(%0) {dim = 0, constant} : (tensor<5xf32>) -> (index)
   return %1 : index
 }
 
diff --git a/mlir/test/Dialect/Vector/test-scalable-bounds.mlir b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
index 245a6f5c13ac3d..d549c5bd1c3785 100644
--- a/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
+++ b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
@@ -26,8 +26,8 @@ func.func @fixed_size_loop_nest() {
     %min_i = affine.min #map_dim_i(%i)[%c4_vscale]
     scf.for %j = %c0 to %c16 step %c4_vscale {
       %min_j = affine.min #map_dim_j(%j)[%c4_vscale]
-      %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
-      %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+      %bound_i = "test.reify_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
+      %bound_j = "test.reify_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
       "test.some_use"(%bound_i, %bound_j) : (index, index) -> ()
     }
   }
@@ -58,8 +58,8 @@ func.func @dynamic_size_loop_nest(%dim0: index, %dim1: index) {
     %min_i = affine.min #map_dynamic_dim(%i)[%c4_vscale, %dim0]
     scf.for %j = %c0 to %dim1 step %c4_vscale {
       %min_j = affine.min #map_dynamic_dim(%j)[%c4_vscale, %dim1]
-      %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
-      %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+      %bound_i = "test.reify_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
+      %bound_j = "test.reify_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
       "test.some_use"(%bound_i, %bound_j) : (index, index) -> ()
     }
   }
@@ -80,7 +80,7 @@ func.func @add_to_vscale() {
   %vscale = vector.vscale
   %c8 = arith.constant 8 : index
   %vscale_plus_c8 = arith.addi %vscale, %c8 : index
-  %bound = "test.reify_scalable_bound"(%vscale_plus_c8) {type = "EQ", vscale_min = 1, vscale_max = 16} : (index) -> index
+  %bound = "test.reify_bound"(%vscale_plus_c8) {type = "EQ", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
   "test.some_use"(%bound) : (index) -> ()
   return
 }
@@ -94,7 +94,7 @@ func.func @add_to_vscale() {
 //       CHECK:   "test.some_use"(%[[C2]]) : (index) -> ()
 func.func @vscale_fixed_size() {
   %vscale = vector.vscale
-  %bound = "test.reify_scalable_bound"(%vscale) {type = "EQ", vscale_min = 2, vscale_max = 2} : (index) -> index
+  %bound = "test.reify_bound"(%vscale) {type = "EQ", vscale_min = 2, vscale_max = 2, scalable} : (index) -> index
   "test.some_use"(%bound) : (index) -> ()
   return
 }
@@ -107,7 +107,7 @@ func.func @unknown_bound(%a: index) {
   %vscale = vector.vscale
   %vscale_plus_a = arith.muli %vscale, %a : index
   // expected-error @below{{could not reify bound}}
-  %bound = "test.reify_scalable_bound"(%vscale_plus_a) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+  %bound = "test.reify_bound"(%vscale_plus_a) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
   "test.some_use"(%bound) : (index) -> ()
   return
 }
@@ -134,7 +134,7 @@ func.func @duplicate_vscale_values() {
   %c2_vscale = arith.muli %vscale_1, %c2 : index
   %add = arith.addi %c2_vscale, %c4_vscale : index
 
-  %bound = "test.reify_scalable_bound"(%add) {type = "EQ", vscale_min = 1, vscale_max = 16} : (index) -> index
+  %bound = "test.reify_bound"(%add) {type = "EQ", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
   "test.some_use"(%bound) : (index) -> ()
   return
 }
@@ -154,7 +154,7 @@ func.func @non_scalable_code() {
   %c0 = arith.constant 0 : index
   scf.for %i = %c0 to %c1024 step %c4 {
     %min_i = affine.min #map_dim_i(%i)[%c4]
-    %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+    %bound_i = "test.reify_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
     "test.some_use"(%bound_i) : (index) -> ()
   }
   return
diff --git a/mlir/test/lib/Dialect/Affine/CMakeLists.txt b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
index 14960a45d39bab..33cefab9fa2edf 100644
--- a/mlir/test/lib/Dialect/Affine/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
@@ -30,5 +30,13 @@ add_mlir_library(MLIRAffineTransformsTestPasses
   MLIRSupport
   MLIRMemRefDialect
   MLIRTensorDialect
+  MLIRTestDialect
   MLIRVectorUtils
   )
+
+target_include_directories(MLIRAffineTransformsTestPasses
+  PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/../Test
+  ${CMAKE_CURRENT_BINARY_DIR}/../Test
+  )
+
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index f38631054fb3c1..6730f9b292ad93 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "TestDialect.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
@@ -57,31 +58,6 @@ struct TestReifyValueBounds
 
 } // namespace
 
-static FailureOr<BoundType> parseBoundType(const std::string &type) {
-  if (type == "EQ")
-    return BoundType::EQ;
-  if (type == "LB")
-    return BoundType::LB;
-  if (type == "UB")
-    return BoundType::UB;
-  return failure();
-}
-
-static FailureOr<ValueBoundsConstraintSet::ComparisonOperator>
-parseComparisonOperator(const std::string &type) {
-  if (type == "EQ")
-    return ValueBoundsConstraintSet::ComparisonOperator::EQ;
-  if (type == "LT")
-    return ValueBoundsConstraintSet::ComparisonOperator::LT;
-  if (type == "LE")
-    return ValueBoundsConstraintSet::ComparisonOperator::LE;
-  if (type == "GT")
-    return ValueBoundsConstraintSet::ComparisonOperator::GT;
-  if (type == "GE")
-    return ValueBoundsConstraintSet::ComparisonOperator::GE;
-  return failure();
-}
-
 static ValueBoundsConstraintSet::ComparisonOperator
 invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp) {
   if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LT)
@@ -101,144 +77,89 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
                                           bool reifyToFuncArgs,
                                           bool useArithOps) {
   IRRewriter rewriter(funcOp.getContext());
-  WalkResult result = funcOp.walk([&](Operation *op) {
-    // Look for test.reify_bound ops.
-    if (op->getName().getStringRef() == "test.reify_bound" ||
-        op->getName().getStringRef() == "test.reify_constant_bound" ||
-        op->getName().getStringRef() == "test.reify_scalable_bound") {
-      if (op->getNumOperands() != 1 || op->getNumResults() != 1 ||
-          !op->getResultTypes()[0].isIndex()) {
-        op->emitOpError("invalid op");
-        return WalkResult::skip();
-      }
-      Value value = op->getOperand(0);
-      if (isa<IndexType>(value.getType()) !=
-          !op->hasAttrOfType<IntegerAttr>("dim")) {
-        // Op should have "dim" attribute if and only if the operand is an
-        // index-typed value.
-        op->emitOpError("invalid op");
-        return WalkResult::skip();
-      }
-
-      // Get bound type.
-      std::string boundTypeStr = "EQ";
-      if (auto boundTypeAttr = op->getAttrOfType<StringAttr>("type"))
-        boundTypeStr = boundTypeAttr.str();
-      auto boundType = parseBoundType(boundTypeStr);
-      if (failed(boundType)) {
-        op->emitOpError("invalid op");
-        return WalkResult::interrupt();
-      }
-
-      // Get shape dimension (if any).
-      auto dim = value.getType().isIndex()
-                     ? std::nullopt
-                     : std::make_optional<int64_t>(
-                           op->getAttrOfType<IntegerAttr>("dim").getInt());
-
-      // Check if a constant was requested.
-      bool constant =
-          op->getName().getStringRef() == "test.reify_constant_bound";
-
-      bool scalable = !constant && op->getName().getStringRef() ==
-                                       "test.reify_scalable_bound";
-
-      // Prepare stop condition. By default, reify in terms of the op's
-      // operands. No stop condition is used when a constant was requested.
-      std::function<bool(Value, std::optional<int64_t>,
-                         ValueBoundsConstraintSet & cstr)>
-          stopCondition = [&](Value v, std::optional<int64_t> d,
-                              ValueBoundsConstraintSet &cstr) {
-            // Reify in terms of SSA values that are different from `value`.
-            return v != value;
-          };
-      if (reifyToFuncArgs) {
-        // Reify in terms of function block arguments.
-        stopCondition = [](Value v, std::optional<int64_t> d,
-                           ValueBoundsConstraintSet &cstr) {
-          auto bbArg = dyn_cast<BlockArgument>(v);
-          if (!bbArg)
-            return false;
-          return isa<FunctionOpInterface>(
-              bbArg.getParentBlock()->getParentOp());
+  WalkResult result = funcOp.walk([&](test::ReifyBoundOp op) {
+    auto boundType = op.getBoundType();
+    Value value = op.getVar();
+    std::optional<int64_t> dim = op.getDim();
+    bool constant = op.getConstant();
+    bool scalable = op.getScalable();
+
+    // Prepare stop condition. By default, reify in terms of the op's
+    // operands. No stop condition is used when a constant was requested.
+    std::function<bool(Value, std::optional<int64_t>,
+                       ValueBoundsConstraintSet & cstr)>
+        stopCondition = [&](Value v, std::optional<int64_t> d,
+                            ValueBoundsConstraintSet &cstr) {
+          // Reify in terms of SSA values that are different from `value`.
+          return v != value;
         };
-      }
+    if (reifyToFuncArgs) {
+      // Reify in terms of function block arguments.
+      stopCondition = [](Value v, std::optional<int64_t> d,
+                         ValueBoundsConstraintSet &cstr) {
+        auto bbArg = dyn_cast<BlockArgument>(v);
+        if (!bbArg)
+          return false;
+        return isa<FunctionOpInterface>(bbArg.getParentBlock()->getParentOp());
+      };
+    }
 
-      // Reify value bound
-      rewriter.setInsertionPointAfter(op);
-      FailureOr<OpFoldResult> reified = failure();
-      if (constant) {
-        auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
-            *boundType, value, dim, /*stopCondition=*/nullptr);
-        if (succeeded(reifiedConst))
-          reified =
-              FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
-      } else if (scalable) {
-        unsigned vscaleMin = 0;
-        unsigned vscaleMax = 0;
-        if (auto attr = "vscale_min"; op->hasAttrOfType<IntegerAttr>(attr)) {
-          vscaleMin = unsigned(op->getAttrOfType<IntegerAttr>(attr).getInt());
-        } else {
-          op->emitOpError("expected `vscale_min` to be provided");
-          return WalkResult::skip();
+    // Reify value bound
+    rewriter.setInsertionPointAfter(op);
+    FailureOr<OpFoldResult> reified = failure();
+    if (constant) {
+      auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
+          boundType, value, dim, /*stopCondition=*/nullptr);
+      if (succeeded(reifiedConst))
+        reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
+    } else if (scalable) {
+      auto loc = op->getLoc();
+      auto reifiedScalable =
+          vector::ScalableValueBoundsConstraintSet::computeScalableBound(
+              value, dim, *op.getVscaleMin(), *op.getVscaleMax(), boundType);
+      if (succeeded(reifiedScalable)) {
+        SmallVector<std::pair<Value, std::optional<int64_t>>, 1> vscaleOperand;
+        if (reifiedScalable->map.getNumInputs() == 1) {
+          // The only possible input to the bound is vscale.
+          vscaleOperand.push_back(std::make_pair(
+              rewriter.create<vector::VectorScaleOp>(loc), std::nullopt));
         }
-        if (auto attr = "vscale_max"; op->hasAttrOfType<IntegerAttr>(attr)) {
-          vscaleMax = unsigned(op->getAttrOfType<IntegerAttr>(attr).getInt());
+        reified = affine::materializeComputedBound(
+            rewriter, loc, reifiedScalable->map, vscaleOperand);
+      }
+    } else {
+      if (dim) {
+        if (useArithOps) {
+          reified = arith::reifyShapedValueDimBound(
+              rewriter, op->getLoc(), boundType, value, *dim, stopCondition);
         } else {
-          op->emitOpError("expected `vscale_max` to be provided");
-          return WalkResult::skip();
-        }
-
-        auto loc = op->getLoc();
-        auto reifiedScalable =
-            vector::ScalableValueBoundsConstraintSet::computeScalableBound(
-                value, dim, vscaleMin, vscaleMax, *boundType);
-        if (succeeded(reifiedScalable)) {
-          SmallVector<std::pair<Value, std::optional<int64_t>>, 1>
-              vscaleOperand;
-          if (reifiedScalable->map.getNumInputs() == 1) {
-            // The only possible input to the bound is vscale.
-            vscaleOperand.push_back(std::make_pair(
-                rewriter.create<vector::VectorScaleOp>(loc), std::nullopt));
-          }
-          reified = affine::materializeComputedBound(
-              rewriter, loc, reifiedScalable->map, vscaleOperand);
+          reified = reifyShapedValueDimBound(rewriter, op->getLoc(), boundType,
+                                             value, *dim, stopCondition);
         }
       } else {
-        if (dim) {
-          if (useArithOps) {
-            reified = arith::reifyShapedValueDimBound(
-                rewriter, op->getLoc(), *boundType, value, *dim, stopCondition);
-          } else {
-            reified = reifyShapedValueDimBound(
-                rewriter, op->getLoc(), *boundType, value, *dim, stopCondition);
-          }
+        if (useArithOps) {
+          reified = arith::reifyIndexValueBound(
+              rewriter, op->getLoc(), boundType, value, stopCondition);
         } else {
-          if (useArithOps) {
-            reified = arith::reifyIndexValueBound(
-                rewriter, op->getLoc(), *boundType, value, stopCondition);
-          } else {
-            reified = reifyIndexValueBound(rewriter, op->getLoc(), *boundType,
-                                           value, stopCondition);
-          }
+          reified = reifyIndexValueBound(rewriter, op->getLoc(), boundType,
+                                         value, stopCondition);
         }
       }
-      if (failed(reified)) {
-        op->emitOpError("could not reify bound");
-        return WalkResult::interrupt();
-      }
+    }
+    if (failed(reified)) {
+      op->emitOpError("could not reify bound");
+      return WalkResult::interrupt();
+    }
 
-      // Replace the op with the reified bound.
-      if (auto val = llvm::dyn_cast_if_present<Value>(*reified)) {
-        rewriter.replaceOp(op, val);
-        return WalkResult::skip();
-      }
-      Value constOp = rewriter.create<arith::ConstantIndexOp>(
-          op->getLoc(), cast<IntegerAttr>(reified->get<Attribute>()).getInt());
-      rewriter.replaceOp(op, constOp);
+    // Replace the op with the reified bound.
+    if (auto val = llvm::dyn_cast_if_present<Value>(*reified)) {
+      rewriter.replaceOp(op, val);
       return WalkResult::skip();
     }
-    return WalkResult::advance();
+    Value constOp = rewriter.create<arith::ConstantIndexOp>(
+        op->getLoc(), cast<IntegerAttr>(reified->get<Attribute>()).getInt());
+    rewriter.replaceOp(op, constOp);
+    return WalkResult::skip();
   });
   return failure(result.wasInterrupted());
 }
@@ -246,60 +167,42 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
 /// Look for "test.compare" ops and emit errors/remarks.
 static LogicalResult testEquality(func::FuncOp funcOp) {
   IRRewriter rewriter(funcOp.getContext());
-  WalkResult result = funcOp.walk([&](Operation *op) {
-    // Look for test.compare ops.
-    if (op->getName().getStringRef() == "test.compare") {
-      if (op->getNumOperands() != 2 || !op->getOperand(0).getType().isIndex() ||
-          !op->getOperand(1).getType().isIndex()) {
-        op->emitOpError("invalid op");
-        return WalkResult::skip();
-      }
-
-      // Get comparison operator.
-      std::string cmpStr = "EQ";
-      if (auto cmpAttr = op->getAttrOfType<StringAttr>("cmp"))
-        cmpStr = cmpAttr.str();
-      auto cmpType = parseComparisonOperator(cmpStr);
-      if (failed(cmpType)) {
-        op->emitOpError("invalid comparison operator");
+  WalkResult result = funcOp.walk([&](test::CompareOp op) {
+    auto cmpType = op.getComparisonOperator();
+    if (op.getCompose()) {
+      if (cmpType != ValueBoundsConstraintSet::EQ) {
+        op->emitOpError(
+            "comparison operator must be EQ when 'composed' is specified");
         return WalkResult::interrupt();
       }
-
-      if (op->hasAttr("compose")) {
-        if (cmpType != ValueBoundsConstraintSet::EQ) {
-          op->emitOpError(
-              "comparison operator must be EQ when 'composed' is specified");
-          return WalkResult::interrupt();
-        }
-        FailureOr<int64_t> delta = affine::fullyComposeAndComputeConstantDelta(
-            op->getOperand(0), op->getOperand(1));
-        if (failed(delta)) {
-          op->emitError("could not determine equality");
-        } else if (*delta == 0) {
-          op->emitRemark("equal");
-        } else {
-          op->emitRemark("different");
-        }
-        return WalkResult::advance();
-      }
-
-      auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
-        return ValueBoundsConstraintSet::compare(
-            /*lhs=*/op->getOperand(0), /*lhsDim=*/std::nullopt, cmp,
-            /*rhs=*/op->getOperand(1), /*rhsDim=*/std::nullopt);
-      };
-      if (compare(*cmpType)) {
-        op->emitRemark("true");
-      } else if (*cmpType != ValueBoundsConstraintSet::EQ &&
-                 compare(invertComparisonOperator(*cmpType))) {
-        op->emitRemark("false");
-      } else if (*cmpType == ValueBoundsConstraintSet::EQ &&
-                 (compare(ValueBoundsConstraintSet::ComparisonOperator::LT) ||
-                  compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) {
-        op->emitRemark("false");
+      FailureOr<int64_t> delta = affine::fullyComposeAndComputeConstantDelta(
+          op->getOperand(0), op->getOperand(1));
+      if (failed(delta)) {
+        op->emitError("could not determine equality");
+      } else if (*delta == 0) {
+        op->emitRemark("equal");
       } else {
-        op->emitError("unknown");
+        op->emitRemark("different");
       }
+      return WalkResult::advance();
+    }
+
+    auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
+      return ValueBoundsConstraintSet::compare(
+          /*lhs=*/op.getLhs(), /*lhsDim=*/std::nullopt, cmp,
+          /*rhs=*/op.getRhs(), /*rhsDim=*/std::nullopt);
+    };
+    if (compare(cmpType)) {
+      op->emitRemark("true");
+    } else if (cmpType != ValueBoundsConstraintSet::EQ &&
+               compare(invertComparisonOperator(cmpType))) {
+      op->emitRemark("false");
+    } else if (cmpType == ValueBoundsConstraintSet::EQ &&
+               (compare(ValueBoundsConstraintSet::ComparisonOperator::LT) ||
+                compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) {
+      op->emitRemark("false");
+    } else {
+      op->emitError("unknown");
     }
     return WalkResult::advance();
   });
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index 47ddcf6524748c..d246c0492a3bd5 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -85,6 +85,7 @@ add_mlir_library(MLIRTestDialect
   MLIRTensorDialect
   MLIRTransformUtils
   MLIRTransforms
+  MLIRValueBoundsOpInterface
 )
 
 add_mlir_translation_library(MLIRTestFromLLVMIRTranslation
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 74378633032ce7..25c5190ca0ef3a 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -516,6 +516,54 @@ static void printOptionalCustomParser(AsmPrinter &p, Operation *,
   p.printAttribute(result);
 }
 
+//===----------------------------------------------------------------------===//
+// ReifyBoundOp
+//===----------------------------------------------------------------------===//
+
+::mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
+  if (getType() == "EQ")
+    return ::mlir::presburger::BoundType::EQ;
+  if (getType() == "LB")
+    return ::mlir::presburger::BoundType::LB;
+  if (getType() == "UB")
+    return ::mlir::presburger::BoundType::UB;
+  llvm_unreachable("invalid bound type");
+}
+
+LogicalResult ReifyBoundOp::verify() {
+  if (isa<ShapedType>(getVar().getType())) { // shaped: 'dim' is required
+    if (!getDim().has_value())
+      return emitOpError("expected 'dim' attribute for shaped type variable");
+  } else if (getVar().getType().isIndex()) { // index: 'dim' is forbidden
+    if (getDim().has_value())
+      return emitOpError("unexpected 'dim' attribute for index variable");
+  } else {
+    return emitOpError("expected index-typed variable or shape type variable");
+  }
+  if (getConstant() && getScalable())
+    return emitOpError("'scalable' and 'constant' are mutually exclusive");
+  if (getScalable() != getVscaleMin().has_value()) // both or neither
+    return emitOpError("expected 'vscale_min' if and only if 'scalable'");
+  if (getScalable() != getVscaleMax().has_value()) // both or neither
+    return emitOpError("expected 'vscale_max' if and only if 'scalable'");
+  return success();
+}
+
+::mlir::ValueBoundsConstraintSet::ComparisonOperator
+CompareOp::getComparisonOperator() {
+  if (getCmp() == "EQ")
+    return ValueBoundsConstraintSet::ComparisonOperator::EQ;
+  if (getCmp() == "LT")
+    return ValueBoundsConstraintSet::ComparisonOperator::LT;
+  if (getCmp() == "LE")
+    return ValueBoundsConstraintSet::ComparisonOperator::LE;
+  if (getCmp() == "GT")
+    return ValueBoundsConstraintSet::ComparisonOperator::GT;
+  if (getCmp() == "GE")
+    return ValueBoundsConstraintSet::ComparisonOperator::GE;
+  llvm_unreachable("invalid comparison operator");
+}
+
 //===----------------------------------------------------------------------===//
 // Test removing op with inner ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index 4ba28c47ed1c33..d5b2fbeafc4104 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -41,6 +41,7 @@
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
 #include <memory>
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c88f85b8b6b361..0ab2427038add3 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2184,6 +2184,53 @@ def ForwardBufferOp : TEST_Op<"forward_buffer", [Pure]> {
   let results = (outs AnyRankedOrUnrankedMemRef:$result);
 }
 
+//===----------------------------------------------------------------------===//
+// Test ValueBoundsOpInterface
+//===----------------------------------------------------------------------===//
+
+def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
+  let description = [{
+    Reify a bound for the given index-typed value or dimension size of a shaped
+    value. "LB", "EQ" and "UB" bounds are supported. If `scalable` is set, a
+    bound in terms of "vector.vscale" is computed. `vscale_min` and `vscale_max`
+    must be specified in that case.
+  }];
+
+  let arguments = (ins AnyType:$var,
+                       OptionalAttr<I64Attr>:$dim,
+                       DefaultValuedAttr<StrAttr, "\"EQ\"">:$type,
+                       UnitAttr:$constant,
+                       UnitAttr:$scalable,
+                       OptionalAttr<I64Attr>:$vscale_min,
+                       OptionalAttr<I64Attr>:$vscale_max);
+  let results = (outs Index:$result);
+
+  let extraClassDeclaration = [{
+    ::mlir::presburger::BoundType getBoundType();
+  }];
+
+  let hasVerifier = 1;
+}
+
+def CompareOp : TEST_Op<"compare"> {
+  let description = [{
+    Compare `lhs` and `rhs`. A remark is emitted which indicates whether the
+    specified comparison operator was proven to hold. The remark also indicates
+    whether the opposite comparison operator was proven to hold.
+  }];
+
+  let arguments = (ins Index:$lhs,
+                       Index:$rhs,
+                       DefaultValuedAttr<StrAttr, "\"EQ\"">:$cmp,
+                       UnitAttr:$compose);
+  let results = (outs);
+
+  let extraClassDeclaration = [{
+    ::mlir::ValueBoundsConstraintSet::ComparisonOperator
+        getComparisonOperator();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 684b59e7f62f65..dc5f4047c286db 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -421,6 +421,7 @@ cc_library(
         "//mlir:TranslateLib",
         "//mlir:TransformUtils",
         "//mlir:Transforms",
+        "//mlir:ValueBoundsOpInterface",
         "//mlir:ViewLikeInterface",
     ],
 )
@@ -644,6 +645,7 @@ cc_library(
         "lib/Dialect/Affine/*.cpp",
     ]),
     deps = [
+        ":TestDialect",
         "//llvm:Support",
         "//mlir:AffineAnalysis",
         "//mlir:AffineDialect",



More information about the Mlir-commits mailing list