[Mlir-commits] [mlir] [mlir][vector] Add `extract(transpose(broadcast(x)))` canonicalization (PR #72616)

Jakub Kuderski llvmlistbot at llvm.org
Thu Dec 14 07:38:10 PST 2023


================
@@ -5421,20 +5490,19 @@ LogicalResult vector::TransposeOp::verify() {
   if (vectorType.getRank() != rank)
     return emitOpError("vector result rank mismatch: ") << rank;
   // Verify transposition array.
-  auto transpAttr = getTransp().getValue();
-  int64_t size = transpAttr.size();
+  ArrayRef<int64_t> perm = getPermutation();
+  int64_t size = perm.size();
   if (rank != size)
     return emitOpError("transposition length mismatch: ") << size;
   SmallVector<bool, 8> seen(rank, false);
-  for (const auto &ta : llvm::enumerate(transpAttr)) {
-    int64_t i = llvm::cast<IntegerAttr>(ta.value()).getInt();
-    if (i < 0 || i >= rank)
-      return emitOpError("transposition index out of range: ") << i;
-    if (seen[i])
-      return emitOpError("duplicate position index: ") << i;
-    seen[i] = true;
-    if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
-      return emitOpError("dimension size mismatch at: ") << i;
+  for (const auto &ta : llvm::enumerate(perm)) {
+    if (ta.value() < 0 || ta.value() >= rank)
----------------
kuhar wrote:

While you are at it, it would be nice to rewrite this loop with structured bindings so that the index and the value get better variable names:
```suggestion
  for (auto [idx, XYZ] : llvm::enumerate(perm)) {
    if (XYZ < 0 || XYZ >= rank)
```

https://github.com/llvm/llvm-project/pull/72616


More information about the Mlir-commits mailing list