[Mlir-commits] [mlir] 28fe1a4 - [mlir] Add trait SameOperandsAndResultRank

Eric Kunze llvmlistbot at llvm.org
Thu Sep 7 10:01:53 PDT 2023


Author: Tai Ly
Date: 2023-09-07T16:55:54Z
New Revision: 28fe1a4e5e8af39a6a0fa253b3538cb0905069dc

URL: https://github.com/llvm/llvm-project/commit/28fe1a4e5e8af39a6a0fa253b3538cb0905069dc
DIFF: https://github.com/llvm/llvm-project/commit/28fe1a4e5e8af39a6a0fa253b3538cb0905069dc.diff

LOG: [mlir] Add trait SameOperandsAndResultRank

This adds a native op trait SameOperandsAndResultRank
and associated verifier that checks that an operator's
operands and result types have same ranks if their ranks
are known.

Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: I2d536f77be10f3710d0c8d84c907ff492a984fda

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D156369

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/Interfaces/InferTypeOpInterface.td
    mlir/lib/IR/Operation.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/types.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 67d923adbf9374b..84ba46f4d6f3ec1 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -341,6 +341,7 @@ LogicalResult verifySameOperandsAndResultShape(Operation *op);
 LogicalResult verifySameOperandsElementType(Operation *op);
 LogicalResult verifySameOperandsAndResultElementType(Operation *op);
 LogicalResult verifySameOperandsAndResultType(Operation *op);
+LogicalResult verifySameOperandsAndResultRank(Operation *op);
 LogicalResult verifyResultsAreBoolLike(Operation *op);
 LogicalResult verifyResultsAreFloatLike(Operation *op);
 LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
@@ -1114,6 +1115,17 @@ class SameOperandsAndResultType
   }
 };
 
+/// This class verifies that op has same ranks for all
+/// operands and results types, if known.
+template <typename ConcreteType>
+class SameOperandsAndResultRank
+    : public TraitBase<ConcreteType, SameOperandsAndResultRank> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return impl::verifySameOperandsAndResultRank(op);
+  }
+};
+
 /// This class verifies that any results of the specified op have a boolean
 /// type, a vector thereof, or a tensor thereof.
 template <typename ConcreteType>

diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 1ceaf25a994e01a..54c1c13fd029dbc 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -369,4 +369,7 @@ def ReifyRankedShapedTypeOpInterface :
 // TODO: Change from hard coded to utilizing type inference trait.
 def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;
 
+// Op has the same ranks for all operands and results types, if known.
+def SameOperandsAndResultRank : NativeOpTrait<"SameOperandsAndResultRank">;
+
 #endif // MLIR_INFERTYPEOPINTERFACE

diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 888c146f49539f6..ef98a89f4bb49b6 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1082,6 +1082,51 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
   return success();
 }
 
+LogicalResult OpTrait::impl::verifySameOperandsAndResultRank(Operation *op) {
+  if (failed(verifyAtLeastNOperands(op, 1)))
+    return failure();
+
+  // delegate function that returns true if type is a shaped type with known
+  // rank
+  auto hasRank = [](const Type type) {
+    if (auto shaped_type = dyn_cast<ShapedType>(type))
+      return shaped_type.hasRank();
+
+    return false;
+  };
+
+  auto rankedOperandTypes =
+      llvm::make_filter_range(op->getOperandTypes(), hasRank);
+  auto rankedResultTypes =
+      llvm::make_filter_range(op->getResultTypes(), hasRank);
+
+  // If all operands and results are unranked, then no further verification.
+  if (rankedOperandTypes.empty() && rankedResultTypes.empty())
+    return success();
+
+  // delegate function that returns rank of shaped type with known rank
+  auto getRank = [](const Type type) {
+    return type.cast<ShapedType>().getRank();
+  };
+
+  auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
+                                          : getRank(*rankedResultTypes.begin());
+
+  for (const auto type : rankedOperandTypes) {
+    if (rank != getRank(type)) {
+      return op->emitOpError("operands don't have matching ranks");
+    }
+  }
+
+  for (const auto type : rankedResultTypes) {
+    if (rank != getRank(type)) {
+      return op->emitOpError("result type has different rank than operands");
+    }
+  }
+
+  return success();
+}
+
 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
   Block *block = op->getBlock();
   // Verify that the operation is at the end of the respective parent block.

diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 2d7f5b0043ba0f6..9ceadab8fa4a086 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -692,6 +692,12 @@ def OperandZeroAndResultHaveSameRank :
   let results = (outs AnyShaped:$res);
 }
 
+def OperandsAndResultHaveSameRank :
+    TEST_Op<"operands_and_result_have_same_rank", [SameOperandsAndResultRank]> {
+  let arguments = (ins AnyShaped:$x, AnyShaped:$y);
+  let results = (outs AnyShaped:$res);
+}
+
 def OperandZeroAndResultHaveSameShape :
     TEST_Op<"operand0_and_result_have_same_shape",
             [AllShapesMatch<["x", "res"]>]> {

diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir
index d36a390dff8f185..7652a87037c9295 100644
--- a/mlir/test/mlir-tblgen/types.mlir
+++ b/mlir/test/mlir-tblgen/types.mlir
@@ -377,6 +377,33 @@ func.func @same_rank_failure(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xf32>) {
 
 // -----
 
+// CHECK-LABEL: same_rank_if_known_success
+func.func @same_rank_if_known_success(%t1xi : tensor<1xi32>, %t2xf : tensor<2xf32>, %m3xi : memref<3xi32>, %t1x2xf : tensor<1x2xf32>, %tuxi : tensor<*xi32>) {
+  %0 = "test.operands_and_result_have_same_rank"(%t1xi, %t2xf) : (tensor<1xi32>, tensor<2xf32>) -> (tensor<3xf64>)
+  %1 = "test.operands_and_result_have_same_rank"(%t1xi, %m3xi) : (tensor<1xi32>, memref<3xi32>) -> (tensor<3xi64>)
+  %3 = "test.operands_and_result_have_same_rank"(%tuxi, %t2xf) : (tensor<*xi32>, tensor<2xf32>) -> (tensor<2xf32>)
+  %4 = "test.operands_and_result_have_same_rank"(%t1x2xf, %tuxi) : (tensor<1x2xf32>, tensor<*xi32>) -> (tensor<1x2xf64>)
+  return
+}
+
+// -----
+
+func.func @same_rank_if_known_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) {
+  // expected-error@+1 {{operands don't have matching ranks}}
+  %0 = "test.operands_and_result_have_same_rank"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> (tensor<*xf32>)
+  return
+}
+
+// -----
+
+func.func @same_rank_if_known_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) {
+  // expected-error@+1 {{result type has different rank than operands}}
+  %0 = "test.operands_and_result_have_same_rank"(%arg1, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2x3xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: same_shape_success
 func.func @same_shape_success(%t2x3: tensor<2x3xi32>, %m2x3: memref<2x3xf32>, %v2x3 : vector<2x3xi32>, %t4x5 : tensor<4x5xi32>) {
   "test.operand0_and_result_have_same_shape"(%t2x3, %t4x5) : (tensor<2x3xi32>, tensor<4x5xi32>) -> (tensor<2x3xf32>)


        


More information about the Mlir-commits mailing list