[Mlir-commits] [mlir] [mlir][VectorOps] Extend vector.constant_mask to support 'all true' scalable dims (PR #66638)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Sep 20 01:46:11 PDT 2023


================
@@ -115,43 +114,41 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
       bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
           op, dstType,
-          DenseIntElementsAttr::get(
-              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
-              ArrayRef<bool>{value}));
+          DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
+                                    value));
       return success();
     }
 
-    // Scalable constant masks can only be lowered for the "none set" case.
-    if (cast<VectorType>(dstType).isScalable()) {
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-          op, DenseElementsAttr::get(dstType, false));
-      return success();
-    }
-
-    int64_t trueDim = std::min(dstType.getDimSize(0),
-                               cast<IntegerAttr>(dimSizes[0]).getInt());
+    int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
 
     if (rank == 1) {
-      // Express constant 1-D case in explicit vector form:
-      //   [T,..,T,F,..,F].
-      SmallVector<bool> values(dstType.getDimSize(0));
-      for (int64_t d = 0; d < trueDim; d++)
-        values[d] = true;
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-          op, dstType, rewriter.getBoolVectorAttr(values));
+      if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
+        // Use constant splat for 'all set' or 'none set' dims.
+        // This produces correct code for scalable dimensions.
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+            op, DenseElementsAttr::get(dstType, trueDimSize != 0));
+      } else {
+        // Express constant 1-D case in explicit vector form:
+        //   [T,..,T,F,..,F].
+        SmallVector<bool> values(dstType.getDimSize(0));
+        for (int64_t d = 0; d < trueDimSize; d++)
+          values[d] = true;
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+            op, dstType, rewriter.getBoolVectorAttr(values));
+      }
       return success();
     }
 
-    VectorType lowType =
-        VectorType::get(dstType.getShape().drop_front(), eltType);
-    SmallVector<int64_t> newDimSizes;
-    for (int64_t r = 1; r < rank; r++)
-      newDimSizes.push_back(cast<IntegerAttr>(dimSizes[r]).getInt());
+    if (dstType.getScalableDims().front())
+      return rewriter.notifyMatchFailure(
+          op, "Cannot unroll leading scalable dim in dstType");
----------------
banach-space wrote:

We should be able to test for this in invalid.mlir, right?

https://github.com/llvm/llvm-project/pull/66638


More information about the Mlir-commits mailing list