[Mlir-commits] [mlir] [MLIR] Fix VectorEmulateNarrowType constant op mask bug (PR #116064)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 13 07:16:50 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: lialan (lialan)

<details>
<summary>Changes</summary>

This commit adds support for handling mask constants generated by the `arith.constant` op in the `VectorEmulateNarrowType` pattern. Previously, this pattern would not match due to the lack of mask constant handling in `getCompressedMaskOp`.

The changes include:

1. Updating `getCompressedMaskOp` to recognize and handle `arith.constant` ops as mask value sources.

2. Handling cases where the mask is not aligned with the emulated load width. The compressed mask is adjusted to account for the offset.

Limitations:
- The arith.constant op can only have 1-dimensional constant values.

Resolves: #<!-- -->115742

---
Full diff: https://github.com/llvm/llvm-project/pull/116064.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+44-2) 
- (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+171) 
- (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir (+24) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index eb4ce24548e603..f2e9ae18d3371c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -70,7 +70,9 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
   // Finding the mask creation operation.
-  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+  while (maskOp &&
+         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+             maskOp)) {
     if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
       maskOp = extractOp.getVector().getDefiningOp();
       extractOps.push_back(extractOp);
@@ -78,7 +80,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   }
   auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
   auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
-  if (!createMaskOp && !constantMaskOp)
+  auto constantOp = dyn_cast_or_null<arith::ConstantOp>(maskOp);
+  if (!createMaskOp && !constantMaskOp && !constantOp)
     return failure();
 
   // Computing the "compressed" mask. All the emulation logic (i.e. computing
@@ -129,6 +132,45 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
       auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
       newMask = rewriter.create<arith::ConstantOp>(loc, newMaskType, denseAttr);
     }
+  } else if (constantOp) {
+    assert(shape.size() == 1 && "expected 1-D mask");
+    // Rearrange the original mask values to cover the whole potential loading
+    // region. For example, in the case of using byte-size for emulation, given
+    // the following mask:
+    //
+    //   %mask = vector.constant_mask [0, 1, 0, 1, 0, 0] : vector<6xi2>
+    //
+    // with a front offset of 1, the mask will be padded with zeros in the front
+    // and back so that its length is a multiple of `scale` (and the total
+    // coverage size is a multiple of bytes):
+    //   %new_mask = vector.constant_mask [0, 0, 1, 0, 1, 0, 0, 0] :
+    //   vector<8xi2>
+    //
+    // The %new_mask is now aligned with the effective loading area and can now
+    // be compressed.
+    SmallVector<bool> maskValues(intraDataOffset, false);
+    if (auto denseAttr =
+            mlir::dyn_cast<DenseIntElementsAttr>(constantOp.getValue())) {
+      for (auto value : denseAttr.getValues<bool>()) {
+        maskValues.push_back(value);
+      }
+      while (maskValues.size() < numElements * scale) {
+        maskValues.push_back(false);
+      }
+    } else {
+      return failure();
+    }
+    // Compressing by combining every `scale` elements:
+    SmallVector<bool> compressedMaskValues;
+    for (size_t i = 0; i < maskValues.size(); i += scale) {
+      bool combinedValue = false;
+      for (int j = 0; j < scale; ++j) {
+        combinedValue |= maskValues[i + j];
+      }
+      compressedMaskValues.push_back(combinedValue);
+    }
+    newMask = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
   }
 
   while (!extractOps.empty()) {
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 7ed75ff7f1579c..359162d76219f4 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -249,3 +249,174 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
 // CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
 // CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>
+
+// -----
+
+func.func @vector_store_i2_const(%arg0: vector<3xi2>) {
+    %0 = memref.alloc() : memref<3x3xi2>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+    return
+}
+
+// In this example, emit 2 atomic stores, with the first storing 2 elements and the second storing 1 element.
+// CHECK: func @vector_store_i2_const(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// atomic store of the first byte
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<3xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// atomic store of the second byte
+// CHECK: %[[ADDI:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDI]]] : memref<3xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST3]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST4]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT3]] : i8
+
+// -----
+
+func.func @vector_store_i8_2(%arg0: vector<7xi2>) {
+    %0 = memref.alloc() : memref<3x7xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+    return
+}
+
+// in this example, emit 2 atomic stores and 1 non-atomic store
+
+// CHECK: func @vector_store_i8_2(
+// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
+// CHECK: %[[ALLOC]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// first atomic store
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]} : vector<7xi2> to vector<1xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<6xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// non atomic store part
+// CHECK: %[[ADDR:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<7xi2> to vector<4xi2>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[BITCAST3]], %[[ALLOC]][%[[ADDR]]] : memref<6xi8>, vector<1xi8>
+
+// second atomic store
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[ADDR]], %[[C1]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]} : vector<7xi2> to vector<2xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<6xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] :
+// CHECK-SAME: vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8    
+
+// -----
+
+func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
+    %0 = memref.alloc() : memref<4x1xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+    return
+}
+
+// in this example, only emit 1 atomic store
+// CHECK: func @vector_store_i2_single_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C0]]] : memref<1xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// -----
+
+func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
+  %0 = memref.alloc() : memref<3x5xi2>
+  %mask = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
+    memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  return %1 : vector<5xi2>
+}
+
+// CHECK: func @vector_maskedload_i4_constant_mask_unaligned(
+// CHECK-SAME: %[[PTH:.+]]: vector<5xi2>) -> vector<5xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
+
+// CHECK: %[[CST0:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[CST1:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[PTH]], %[[CST1]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
+
+// Emulated masked load from alloc:
+// CHECK: %[[BCAST:.+]] = vector.bitcast %[[INSERT]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[CST0]], %[[BCAST]]
+// CHECK: %[[BCAST2:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
+
+// Select from emulated loaded vector and passthru vector:
+// TODO: fold this part if possible.
+// CHECK: %[[CST2:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[BCAST2]], %[[CST2]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BCAST2]], %[[INSERT]] : vector<8xi1>, vector<8xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SELECT]]
+// CHECK-SAME: {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
+// CHECK: return %[[EXTRACT]] : vector<5xi2>
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 034bd47f6163e6..19edc9ddcaf2b4 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -624,3 +624,27 @@ func.func @vector_maskedstore_i4_constant_mask(
 // CHECK32:           %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
 // CHECK32:           %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
 // CHECK32:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
+
+// -----
+
+func.func @vector_maskedload_i4_arith_constant(%passthru: vector<8xi4>) -> vector<8xi4> {
+  %0 = memref.alloc() : memref<3x8xi4>
+  %cst = arith.constant dense<0> : vector<8xi4>
+  %mask = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
+  %c0 = arith.constant 0 : index
+  %1 = vector.maskedload %0[%c0, %c0], %mask, %passthru :
+    memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+  return %1 : vector<8xi4>
+}
+
+// CHECK: func @vector_maskedload_i4_arith_constant(
+// CHECK-SAME:   %[[PASSTHRU:[a-zA-Z0-9]+]]: vector<8xi4>) -> vector<8xi4> {
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<24xi8>
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
+// CHECK: %[[CST:.+]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[PASSTHRU]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C0]], %[[C0]]], %[[MASK]], %[[BITCAST]] : memref<24xi8>, vector<8xi1>, vector<4xi8> into vector<4xi8>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[BITCAST2]], %[[PASSTHRU]] : vector<4xi1>, vector<8xi4>
+// CHECK: return %[[SELECT]] : vector<8xi4>

``````````

</details>


https://github.com/llvm/llvm-project/pull/116064


More information about the Mlir-commits mailing list