[Mlir-commits] [mlir] [mlir][VectorOps] Extend vector.constant_mask to support 'all true' scalable dims (PR #66638)

Benjamin Maxwell llvmlistbot at llvm.org
Wed Sep 20 04:11:43 PDT 2023


================
@@ -115,43 +114,41 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
       bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
           op, dstType,
-          DenseIntElementsAttr::get(
-              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
-              ArrayRef<bool>{value}));
+          DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
+                                    value));
       return success();
     }
 
-    // Scalable constant masks can only be lowered for the "none set" case.
-    if (cast<VectorType>(dstType).isScalable()) {
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-          op, DenseElementsAttr::get(dstType, false));
-      return success();
-    }
-
-    int64_t trueDim = std::min(dstType.getDimSize(0),
-                               cast<IntegerAttr>(dimSizes[0]).getInt());
+    int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
 
     if (rank == 1) {
-      // Express constant 1-D case in explicit vector form:
-      //   [T,..,T,F,..,F].
-      SmallVector<bool> values(dstType.getDimSize(0));
-      for (int64_t d = 0; d < trueDim; d++)
-        values[d] = true;
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-          op, dstType, rewriter.getBoolVectorAttr(values));
+      if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
+        // Use constant splat for 'all set' or 'none set' dims.
+        // This produces correct code for scalable dimensions.
----------------
MacDue wrote:

Yes, it becomes a constant splat (I have now noted this in the code comment).

https://github.com/llvm/llvm-project/pull/66638


More information about the Mlir-commits mailing list