[Mlir-commits] [mlir] [mlir][vector] Add `extract(transpose(broadcast(x)))` canonicalization (PR #72616)
Jakub Kuderski
llvmlistbot at llvm.org
Thu Dec 14 07:38:10 PST 2023
================
@@ -5421,20 +5490,19 @@ LogicalResult vector::TransposeOp::verify() {
if (vectorType.getRank() != rank)
return emitOpError("vector result rank mismatch: ") << rank;
// Verify transposition array.
- auto transpAttr = getTransp().getValue();
- int64_t size = transpAttr.size();
+ ArrayRef<int64_t> perm = getPermutation();
+ int64_t size = perm.size();
if (rank != size)
return emitOpError("transposition length mismatch: ") << size;
SmallVector<bool, 8> seen(rank, false);
- for (const auto &ta : llvm::enumerate(transpAttr)) {
- int64_t i = llvm::cast<IntegerAttr>(ta.value()).getInt();
- if (i < 0 || i >= rank)
- return emitOpError("transposition index out of range: ") << i;
- if (seen[i])
- return emitOpError("duplicate position index: ") << i;
- seen[i] = true;
- if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
- return emitOpError("dimension size mismatch at: ") << i;
+ for (const auto &ta : llvm::enumerate(perm)) {
+ if (ta.value() < 0 || ta.value() >= rank)
----------------
kuhar wrote:
While you are at it, it would be nice to rewrite this with structured bindings for better variable names:
```suggestion
for (auto [idx, XYZ] : llvm::enumerate(perm)) {
if (XYZ < 0 || XYZ >= rank)
```
https://github.com/llvm/llvm-project/pull/72616
More information about the Mlir-commits mailing list.