[Mlir-commits] [mlir] 0017dc2 - [mlir][Vector] Use llvm::zip to avoid assertion failed.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 13 18:27:03 PST 2022


Author: jacquesguan
Date: 2022-12-14T10:26:54+08:00
New Revision: 0017dc2d0ff89bd9fb93fd25ba787ba66a4de356

URL: https://github.com/llvm/llvm-project/commit/0017dc2d0ff89bd9fb93fd25ba787ba66a4de356
DIFF: https://github.com/llvm/llvm-project/commit/0017dc2d0ff89bd9fb93fd25ba787ba66a4de356.diff

LOG: [mlir][Vector] Use llvm::zip to avoid assertion failed.

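This patch fixes issue https://github.com/llvm/llvm-project/issues/59455.
The offsets and sizes of a vector.extract_strided_slice may omit the unchanged trailing dimensions, so they can have fewer entries than the mask has dimensions; llvm::zip_equal asserts in that case. Use llvm::zip instead and append the unchanged mask dimension sizes afterwards.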

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D139815

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 4c772c2fb11db..4afbaee3b3e60 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2971,12 +2971,16 @@ class StridedSliceConstantMaskFolder final
     SmallVector<int64_t, 4> sliceMaskDimSizes;
     sliceMaskDimSizes.reserve(maskDimSizes.size());
     for (auto [maskDimSize, sliceOffset, sliceSize] :
-         llvm::zip_equal(maskDimSizes, sliceOffsets, sliceSizes)) {
+         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
       int64_t sliceMaskDimSize = std::max(
           static_cast<int64_t>(0),
           std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
       sliceMaskDimSizes.push_back(sliceMaskDimSize);
     }
+    // Add unchanged dimensions.
+    if (sliceMaskDimSizes.size() < maskDimSizes.size())
+      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
+        sliceMaskDimSizes.push_back(maskDimSizes[i]);
     // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
     // region is a conjunction of mask dim intervals).
     if (llvm::is_contained(sliceMaskDimSizes, 0))

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 2addc78f96dbc..ebadecd11e64c 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2083,3 +2083,15 @@ func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
   %1 = vector.extract %0[0, 0, 31] : vector<1x1x32x1xf32>
   return %1: vector<1xf32>
 }
+
+// -----
+// CHECK-LABEL: func.func @extract_strided_slice_of_constant_mask
+func.func @extract_strided_slice_of_constant_mask() -> vector<5x7xi1>{
+  //  CHECK-NEXT:   %[[RES:.*]] = vector.constant_mask [5, 4] : vector<5x7xi1>
+  //  CHECK-NEXT:   return %[[RES]] : vector<5x7xi1>
+  %c4 = arith.constant 4 : index
+  %c10 = arith.constant 10 : index
+  %mask = vector.create_mask %c10, %c4 : vector<12x7xi1>
+  %res = vector.extract_strided_slice %mask {offsets = [3], sizes = [5], strides = [1]} : vector<12x7xi1> to vector<5x7xi1>
+  return %res : vector<5x7xi1>
+}


        


More information about the Mlir-commits mailing list