[Mlir-commits] [mlir] 829733a - [mlir] Fix SameOperandsAndResultType to check encoding.

Jacques Pienaar llvmlistbot at llvm.org
Wed Dec 21 09:49:25 PST 2022


Author: Jacques Pienaar
Date: 2022-12-21T09:49:18-08:00
New Revision: 829733af4ac2895543797443c82f1f1709472c4f

URL: https://github.com/llvm/llvm-project/commit/829733af4ac2895543797443c82f1f1709472c4f
DIFF: https://github.com/llvm/llvm-project/commit/829733af4ac2895543797443c82f1f1709472c4f.diff

LOG: [mlir] Fix SameOperandsAndResultType to check encoding.

The tensor encoding was accidentally left out of this verifier check even though it forms part of the type.
This is a small tightening step; I'll look at a follow-on change to tighten the check further.

Differential Revision: https://reviews.llvm.org/D140445

Added: 
    

Modified: 
    mlir/lib/IR/Operation.cpp
    mlir/test/IR/traits.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 219f1e21f379f..d44d0b199efcb 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -893,17 +893,30 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
 
   auto type = op->getResult(0).getType();
   auto elementType = getElementTypeOrSelf(type);
+  Attribute encoding = nullptr;
+  if (auto rankedType = dyn_cast<RankedTensorType>(type))
+    encoding = rankedType.getEncoding();
   for (auto resultType : llvm::drop_begin(op->getResultTypes())) {
     if (getElementTypeOrSelf(resultType) != elementType ||
         failed(verifyCompatibleShape(resultType, type)))
       return op->emitOpError()
              << "requires the same type for all operands and results";
+    if (encoding)
+      if (auto rankedType = dyn_cast<RankedTensorType>(resultType);
+          encoding != rankedType.getEncoding())
+        return op->emitOpError()
+               << "requires the same encoding for all operands and results";
   }
   for (auto opType : op->getOperandTypes()) {
     if (getElementTypeOrSelf(opType) != elementType ||
         failed(verifyCompatibleShape(opType, type)))
       return op->emitOpError()
              << "requires the same type for all operands and results";
+    if (encoding)
+      if (auto rankedType = dyn_cast<RankedTensorType>(opType);
+          encoding != rankedType.getEncoding())
+        return op->emitOpError()
+               << "requires the same encoding for all operands and results";
   }
   return success();
 }

diff  --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 80e0d4c8c5e91..ddba1171649c9 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -174,6 +174,14 @@ func.func @failedSameOperandAndResultType_operand_result_mismatch(%t10 : tensor<
 
 // -----
 
+func.func @failedSameOperandAndResultType_encoding_mismatch(%t10 : tensor<10xf32>, %t20 : tensor<10xf32>) {
+  // expected-error at +1 {{requires the same encoding for all operands and results}}
+  "test.same_operand_and_result_type"(%t10, %t20) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32, "enc">
+  return
+}
+
+// -----
+
 func.func @failedElementwiseMappable_
diff erent_rankedness(%arg0: tensor<?xf32>, %arg1: tensor<*xf32>) {
   // expected-error at +1 {{all non-scalar operands/results must have the same shape and base type}}
   %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<*xf32>) -> tensor<*xf32>


        


More information about the Mlir-commits mailing list