[Mlir-commits] [mlir] [mlir][tosa] Add more verifiers for the following operators (PR #127923)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 27 11:25:20 PST 2025


================
@@ -472,6 +472,104 @@ LogicalResult tosa::AvgPool2dOp::verify() {
   return emitOpError("input/output element types are incompatible.");
 }
 
+LogicalResult tosa::CastOp::verify() {
+  mlir::Type inputETy =
+      llvm::cast<ShapedType>(getInput().getType()).getElementType();
+  if (auto inputQuantType =
+          llvm::dyn_cast<mlir::quant::QuantizedType>(inputETy)) {
+    inputETy = inputQuantType.getStorageType();
+  }
+  mlir::Type outputETy =
+      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
+  if (auto outputQuantType =
+          llvm::dyn_cast<mlir::quant::QuantizedType>(outputETy)) {
+    outputETy = outputQuantType.getStorageType();
+  }
+
+  // input element type: bool
+  if (inputETy.isInteger(1)) {
+    if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32)) {
+      return success();
+    }
+  }
+  // input element type: int8
+  if (inputETy.isInteger(8)) {
+    if (outputETy.isInteger(1) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32) || outputETy.isF16() || outputETy.isBF16() ||
+        outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: int16
+  if (inputETy.isInteger(16)) {
+    if (outputETy.isInteger(1) || outputETy.isInteger(8) ||
+        outputETy.isInteger(32) || outputETy.isF16() || outputETy.isBF16() ||
+        outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: int32
+  if (inputETy.isInteger(32)) {
+    if (outputETy.isInteger(1) || outputETy.isInteger(8) ||
+        outputETy.isInteger(16) || outputETy.isF16() || outputETy.isBF16() ||
+        outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: bf16 or fp16
+  if (inputETy.isBF16() || inputETy.isF16()) {
+    if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32) || llvm::isa<Float8E5M2Type>(outputETy) ||
+        llvm::isa<Float8E4M3FNType>(outputETy) || outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: f8e4m3 or f8e5m2
+  if (llvm::isa<Float8E4M3FNType>(inputETy) ||
+      llvm::isa<Float8E5M2Type>(inputETy)) {
+    if (outputETy.isF16() || outputETy.isBF16() || outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: fp32
+  if (inputETy.isF32()) {
+    if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32) || llvm::isa<Float8E5M2Type>(outputETy) ||
+        llvm::isa<Float8E4M3FNType>(outputETy) || outputETy.isF16() ||
+        outputETy.isBF16()) {
+      return success();
+    }
+  }
+
+  // following are outside of TOSA Spec
----------------
Jerry-Ge wrote:

I think those (the cast combinations outside the TOSA spec) should be moved to the validation pass.

https://github.com/llvm/llvm-project/pull/127923


More information about the Mlir-commits mailing list