[Mlir-commits] [mlir] [mlir][tensor] Add a tensor.concat operation (PR #72779)

Nicolas Vasilache llvmlistbot at llvm.org
Mon Nov 20 04:50:29 PST 2023


================
@@ -471,6 +471,202 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<RankedTensorType> ConcatOp::inferResultType(int64_t dim,
+                                                      TypeRange inputTypes) {
+  if (dim < 0)
+    return failure();
+
+  if (inputTypes.empty())
+    return failure();
+
+  RankedTensorType init = dyn_cast<RankedTensorType>(inputTypes[0]);
+  if (!init)
+    return failure();
+
+  // The tensor rank must be greater than the concatenation dim.
+  int64_t concatRank = init.getRank();
+  if (concatRank <= dim)
+    return failure();
+
+  SmallVector<int64_t> sizes(init.getShape());
+  Type elementType = init.getElementType();
+  for (Type type : inputTypes.drop_front()) {
----------------
nicolasvasilache wrote:

If it were me, I would probably structure this differently:
1. For all ranks `!= dim`, if `llvm::any_of` finds a mismatch in `tensor.shape()[rank]` across the inputs, return failure with a proper `emitOpError` message.
2. For `dim`, expose/evolve and reuse the "saturated_arith" helpers from here (see the sketch below): https://github.com/llvm/llvm-project/blob/cfee7152d4eb673976b51b831295dcf5b1811634/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp#L32

It seems people routinely rewrite logic around `isDynamic` etc. that we would be better off commonalizing.
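To make that concrete, here is a minimal sketch of the suggested structure, not the actual patch: `saturatedAdd` is a hypothetical stand-in for the "saturated_arith" helpers linked above, and error reporting is left as a plain `failure()` since `inferResultType` has no op to call `emitOpError` on.

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Hypothetical stand-in for the saturated_arith helpers linked above:
// adding a dynamic size to anything yields a dynamic size.
static int64_t saturatedAdd(int64_t a, int64_t b) {
  if (ShapedType::isDynamic(a) || ShapedType::isDynamic(b))
    return ShapedType::kDynamic;
  return a + b;
}

static FailureOr<RankedTensorType>
inferConcatResultType(int64_t dim, TypeRange inputTypes) {
  if (dim < 0 || inputTypes.empty())
    return failure();
  auto init = dyn_cast<RankedTensorType>(inputTypes[0]);
  if (!init || init.getRank() <= dim)
    return failure();

  SmallVector<int64_t> sizes(init.getShape());
  for (Type type : inputTypes.drop_front()) {
    auto tensorType = dyn_cast<RankedTensorType>(type);
    if (!tensorType || tensorType.getRank() != init.getRank() ||
        tensorType.getElementType() != init.getElementType())
      return failure();
    ArrayRef<int64_t> shape = tensorType.getShape();

    // 1. Every non-concatenated dimension must match (dynamic matches any).
    //    With access to the op, this is where emitOpError would go.
    if (llvm::any_of(llvm::seq<int64_t>(0, init.getRank()), [&](int64_t rank) {
          return rank != dim && !ShapedType::isDynamic(sizes[rank]) &&
                 !ShapedType::isDynamic(shape[rank]) &&
                 sizes[rank] != shape[rank];
        }))
      return failure();

    // 2. The concatenated dimension accumulates with saturating semantics:
    //    any dynamic input makes the result dimension dynamic.
    sizes[dim] = saturatedAdd(sizes[dim], shape[dim]);
  }
  return RankedTensorType::get(sizes, init.getElementType());
}
```

The point is mainly the split: mismatch checking for the non-concatenated dimensions in one place, and a shared saturating-add utility for the concatenated dimension, rather than open-coding the `isDynamic` handling again.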

https://github.com/llvm/llvm-project/pull/72779

