[Mlir-commits] [mlir] 181d960 - [mlir][vectorize] Support affine.apply in SuperVectorize (#77968)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 16 03:30:21 PST 2024


Author: Hsiangkai Wang
Date: 2024-02-16T11:30:18Z
New Revision: 181d9602f35c476201ecc72c26c00042ad949544

URL: https://github.com/llvm/llvm-project/commit/181d9602f35c476201ecc72c26c00042ad949544
DIFF: https://github.com/llvm/llvm-project/commit/181d9602f35c476201ecc72c26c00042ad949544.diff

LOG: [mlir][vectorize] Support affine.apply in SuperVectorize (#77968)

affine.apply does not need to be vectorized inside the vectorizing loop.
However, it still has to be regenerated in its original scalar form, with
each operand replaced by the corresponding scalar value generated in the
vectorized loop (e.g., the new induction variable).

Added: 
    mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir

Modified: 
    mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 6b7a157925fae1..46c7871f40232f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -711,18 +711,16 @@ struct VectorizationState {
                                          BlockArgument replacement);
 
   /// Registers the scalar replacement of a scalar value. 'replacement' must be
-  /// scalar. Both values must be block arguments. Operation results should be
-  /// replaced using the 'registerOp*' utilitites.
+  /// scalar.
   ///
   /// This utility is used to register the replacement of block arguments
-  /// that are within the loop to be vectorized and will continue being scalar
-  /// within the vector loop.
+  /// or affine.apply results that are within the loop to be vectorized and
+  /// will continue being scalar within the vector loop.
   ///
   /// Example:
   ///   * 'replaced': induction variable of a loop to be vectorized.
   ///   * 'replacement': new induction variable in the new vector loop.
-  void registerValueScalarReplacement(BlockArgument replaced,
-                                      BlockArgument replacement);
+  void registerValueScalarReplacement(Value replaced, Value replacement);
 
   /// Registers the scalar replacement of a scalar result returned from a
   /// reduction loop. 'replacement' must be scalar.
@@ -772,7 +770,6 @@ struct VectorizationState {
   /// Internal implementation to map input scalar values to new vector or scalar
   /// values.
   void registerValueVectorReplacementImpl(Value replaced, Value replacement);
-  void registerValueScalarReplacementImpl(Value replaced, Value replacement);
 };
 
 } // namespace
@@ -844,19 +841,22 @@ void VectorizationState::registerValueVectorReplacementImpl(Value replaced,
 }
 
 /// Registers the scalar replacement of a scalar value. 'replacement' must be
-/// scalar. Both values must be block arguments. Operation results should be
-/// replaced using the 'registerOp*' utilitites.
+/// scalar.
 ///
 /// This utility is used to register the replacement of block arguments
-/// that are within the loop to be vectorized and will continue being scalar
-/// within the vector loop.
+/// or affine.apply results that are within the loop to be vectorized and
+/// will continue being scalar within the vector loop.
 ///
 /// Example:
 ///   * 'replaced': induction variable of a loop to be vectorized.
 ///   * 'replacement': new induction variable in the new vector loop.
-void VectorizationState::registerValueScalarReplacement(
-    BlockArgument replaced, BlockArgument replacement) {
-  registerValueScalarReplacementImpl(replaced, replacement);
+void VectorizationState::registerValueScalarReplacement(Value replaced,
+                                                        Value replacement) {
+  assert(!valueScalarReplacement.contains(replaced) &&
+         "Scalar value replacement already registered");
+  assert(!isa<VectorType>(replacement.getType()) &&
+         "Expected scalar type in scalar replacement");
+  valueScalarReplacement.map(replaced, replacement);
 }
 
 /// Registers the scalar replacement of a scalar result returned from a
@@ -879,15 +879,6 @@ void VectorizationState::registerLoopResultScalarReplacement(
   loopResultScalarReplacement[replaced] = replacement;
 }
 
-void VectorizationState::registerValueScalarReplacementImpl(Value replaced,
-                                                            Value replacement) {
-  assert(!valueScalarReplacement.contains(replaced) &&
-         "Scalar value replacement already registered");
-  assert(!isa<VectorType>(replacement.getType()) &&
-         "Expected scalar type in scalar replacement");
-  valueScalarReplacement.map(replaced, replacement);
-}
-
 /// Returns in 'replacedVals' the scalar replacement for values in 'inputVals'.
 void VectorizationState::getScalarValueReplacementsFor(
     ValueRange inputVals, SmallVectorImpl<Value> &replacedVals) {
@@ -978,6 +969,33 @@ static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp,
   return newConstOp;
 }
 
+/// affine.apply is not vectorized: it is re-created in scalar form, with each
+/// operand remapped through valueScalarReplacement (e.g. old induction
+/// variables -> new ones). Returns nullptr if any operand was vectorized.
+static Operation *vectorizeAffineApplyOp(AffineApplyOp applyOp,
+                                         VectorizationState &state) {
+  SmallVector<Value, 8> updatedOperands;
+  for (Value operand : applyOp.getOperands()) {
+    // A scalar affine.apply cannot consume a vectorized value.
+    if (state.valueVectorReplacement.contains(operand)) {
+      LLVM_DEBUG(
+          dbgs() << "\n[early-vect]+++++ affine.apply on vector operand\n");
+      return nullptr;
+    }
+    // Operands with no registered scalar replacement (e.g. values defined
+    // outside the loop being vectorized) are used unchanged.
+    updatedOperands.push_back(
+        state.valueScalarReplacement.lookupOrDefault(operand));
+  }
+
+  auto newApplyOp = state.builder.create<AffineApplyOp>(
+      applyOp.getLoc(), applyOp.getAffineMap(), updatedOperands);
+  // Map the original result to the new scalar one for subsequent users.
+  state.registerValueScalarReplacement(applyOp.getResult(),
+                                       newApplyOp.getResult());
+  return newApplyOp;
+}
+
 /// Creates a constant vector filled with the neutral elements of the given
 /// reduction. The scalar type of vector elements will be taken from
 /// `oldOperand`.
@@ -1184,11 +1202,17 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
   SmallVector<Value, 8> indices;
   indices.reserve(memRefType.getRank());
   if (loadOp.getAffineMap() !=
-      state.builder.getMultiDimIdentityMap(memRefType.getRank()))
+      state.builder.getMultiDimIdentityMap(memRefType.getRank())) {
+    // Check the operand in loadOp affine map does not come from AffineApplyOp.
+    for (auto op : mapOperands) {
+      if (op.getDefiningOp<AffineApplyOp>())
+        return nullptr;
+    }
     computeMemoryOpIndices(loadOp, loadOp.getAffineMap(), mapOperands, state,
                            indices);
-  else
+  } else {
     indices.append(mapOperands.begin(), mapOperands.end());
+  }
 
   // Compute permutation map using the information of new vector loops.
   auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(),
@@ -1493,6 +1517,8 @@ static Operation *vectorizeOneOperation(Operation *op,
     return vectorizeAffineYieldOp(yieldOp, state);
   if (auto constant = dyn_cast<arith::ConstantOp>(op))
     return vectorizeConstant(constant, state);
+  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
+    return vectorizeAffineApplyOp(applyOp, state);
 
   // Other ops with regions are not supported.
   if (op->getNumRegions() != 0)

diff  --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir
new file mode 100644
index 00000000000000..15a7133cf0f65f
--- /dev/null
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir
@@ -0,0 +1,159 @@
+// RUN: mlir-opt %s -affine-super-vectorize="virtual-vector-size=8 test-fastest-varying=0" -split-input-file | FileCheck %s
+// Scalar affine.apply results index the load, which is still vectorized.
+// CHECK-DAG: #[[$MAP_ID0:map[0-9a-zA-Z_]*]] = affine_map<(d0) -> (d0 mod 12)>
+// CHECK-DAG: #[[$MAP_ID1:map[0-9a-zA-Z_]*]] = affine_map<(d0) -> (d0 mod 16)>
+
+// CHECK-LABEL: vec_affine_apply
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<8x12x16xf32>, %[[ARG1:.*]]: memref<8x24x48xf32>) {
+func.func @vec_affine_apply(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf32>) {
+// CHECK:       affine.for %[[ARG2:.*]] = 0 to 8 {
+// CHECK-NEXT:    affine.for %[[ARG3:.*]] = 0 to 24 {
+// CHECK-NEXT:      affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
+// CHECK-NEXT:        %[[S0:.*]] = affine.apply #[[$MAP_ID0]](%[[ARG3]])
+// CHECK-NEXT:        %[[S1:.*]] = affine.apply #[[$MAP_ID1]](%[[ARG4]])
+// CHECK-NEXT:        %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:        %[[S2:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[S0]], %[[S1]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
+// CHECK-NEXT:        vector.transfer_write %[[S2]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return
+  affine.for %arg2 = 0 to 8 {
+    affine.for %arg3 = 0 to 24 {
+      affine.for %arg4 = 0 to 48 {
+        %0 = affine.apply affine_map<(d0) -> (d0 mod 12)>(%arg3)
+        %1 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%arg4)
+        %2 = affine.load %arg0[%arg2, %0, %1] : memref<8x12x16xf32>
+        affine.store %2, %arg1[%arg2, %arg3, %arg4] : memref<8x24x48xf32>
+      }
+    }
+  }
+  return
+}
+
+// -----
+// Only the fastest-varying load index is produced by an affine.apply.
+// CHECK-DAG: #[[$MAP_ID2:map[0-9a-zA-Z_]*]] = affine_map<(d0) -> (d0 mod 16 + 1)>
+
+// CHECK-LABEL: vec_affine_apply_2
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<8x12x16xf32>, %[[ARG1:.*]]: memref<8x24x48xf32>) {
+func.func @vec_affine_apply_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf32>) {
+// CHECK:      affine.for %[[ARG2:.*]] = 0 to 8 {
+// CHECK-NEXT:   affine.for %[[ARG3:.*]] = 0 to 12 {
+// CHECK-NEXT:     affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
+// CHECK-NEXT:       %[[S0:.*]] = affine.apply #[[$MAP_ID2]](%[[ARG4]])
+// CHECK-NEXT:       %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:       %[[S1:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[ARG3]], %[[S0]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
+// CHECK-NEXT:       vector.transfer_write %[[S1]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+  affine.for %arg2 = 0 to 8 {
+    affine.for %arg3 = 0 to 12 {
+      affine.for %arg4 = 0 to 48 {
+        %1 = affine.apply affine_map<(d0) -> (d0 mod 16 + 1)>(%arg4)
+        %2 = affine.load %arg0[%arg2, %arg3, %1] : memref<8x12x16xf32>
+        affine.store %2, %arg1[%arg2, %arg3, %arg4] : memref<8x24x48xf32>
+      }
+    }
+  }
+  return
+}
+
+// -----
+// Stays scalar: an affine.apply depends on loaded data. $MAP_ID0/1 are from case 1.
+// CHECK-LABEL: no_vec_affine_apply
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<8x12x16xi32>, %[[ARG1:.*]]: memref<8x24x48xi32>) {
+func.func @no_vec_affine_apply(%arg0: memref<8x12x16xi32>, %arg1: memref<8x24x48xi32>) {
+// CHECK:      affine.for %[[ARG2:.*]] = 0 to 8 {
+// CHECK-NEXT:   affine.for %[[ARG3:.*]] = 0 to 24 {
+// CHECK-NEXT:     affine.for %[[ARG4:.*]] = 0 to 48 {
+// CHECK-NEXT:       %[[S0:.*]] = affine.apply #[[$MAP_ID0]](%[[ARG3]])
+// CHECK-NEXT:       %[[S1:.*]] = affine.apply #[[$MAP_ID1]](%[[ARG4]])
+// CHECK-NEXT:       %[[S2:.*]] = affine.load %[[ARG0]][%[[ARG2]], %[[S0]], %[[S1]]] : memref<8x12x16xi32>
+// CHECK-NEXT:       %[[S3:.*]] = arith.index_cast %[[S2]] : i32 to index
+// CHECK-NEXT:       %[[S4:.*]] = affine.apply #[[$MAP_ID1]](%[[S3]])
+// CHECK-NEXT:       %[[S5:.*]] = arith.index_cast %[[S4]] : index to i32
+// CHECK-NEXT:       affine.store %[[S5]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : memref<8x24x48xi32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+  affine.for %arg2 = 0 to 8 {
+    affine.for %arg3 = 0 to 24 {
+      affine.for %arg4 = 0 to 48 {
+        %0 = affine.apply affine_map<(d0) -> (d0 mod 12)>(%arg3)
+        %1 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%arg4)
+        %2 = affine.load %arg0[%arg2, %0, %1] : memref<8x12x16xi32>
+        %3 = arith.index_cast %2 : i32 to index
+        %4 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%3)
+        %5 = arith.index_cast %4 : index to i32
+        affine.store %5, %arg1[%arg2, %arg3, %arg4] : memref<8x24x48xi32>
+      }
+    }
+  }
+  return
+}
+
+// -----
+// Non-identity load map with an affine.apply operand: the loop stays scalar.
+// CHECK-DAG: #[[$MAP_ID1:map[0-9a-zA-Z_]*]] = affine_map<(d0) -> (d0 mod 16)>
+
+// CHECK-LABEL: affine_map_with_expr
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<8x12x16xf32>, %[[ARG1:.*]]: memref<8x24x48xf32>) {
+func.func @affine_map_with_expr(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf32>) {
+// CHECK:      affine.for %[[ARG2:.*]] = 0 to 8 {
+// CHECK-NEXT:   affine.for %[[ARG3:.*]] = 0 to 12 {
+// CHECK-NEXT:     affine.for %[[ARG4:.*]] = 0 to 48 {
+// CHECK-NEXT:       %[[S0:.*]] = affine.apply #[[$MAP_ID1]](%[[ARG4]])
+// CHECK-NEXT:       %[[S1:.*]] = affine.load %[[ARG0]][%[[ARG2]], %[[ARG3]], %[[S0]] + 1] : memref<8x12x16xf32>
+// CHECK-NEXT:       affine.store %[[S1]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : memref<8x24x48xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+  affine.for %arg2 = 0 to 8 {
+    affine.for %arg3 = 0 to 12 {
+      affine.for %arg4 = 0 to 48 {
+        %1 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%arg4)
+        %2 = affine.load %arg0[%arg2, %arg3, %1 + 1] : memref<8x12x16xf32>
+        affine.store %2, %arg1[%arg2, %arg3, %arg4] : memref<8x24x48xf32>
+      }
+    }
+  }
+  return
+}
+
+// -----
+// Non-identity load map without affine.apply operands is still vectorized.
+// CHECK-DAG: #[[$MAP_ID3:map[0-9a-zA-Z_]*]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK-DAG: #[[$MAP_ID4:map[0-9a-zA-Z_]*]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK-DAG: #[[$MAP_ID5:map[0-9a-zA-Z_]*]] = affine_map<(d0, d1, d2) -> (d2 + 1)>
+// CHECK-DAG: #[[$MAP_ID6:map[0-9a-zA-Z_]*]] = affine_map<(d0, d1, d2) -> (0)>
+
+// CHECK-LABEL: affine_map_with_expr_2
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<8x12x16xf32>, %[[ARG1:.*]]: memref<8x24x48xf32>, %[[I0:.*]]: index) {
+func.func @affine_map_with_expr_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf32>, %i: index) {
+// CHECK:      affine.for %[[ARG3:.*]] = 0 to 8 {
+// CHECK-NEXT:   affine.for %[[ARG4:.*]] = 0 to 12 {
+// CHECK-NEXT:     affine.for %[[ARG5:.*]] = 0 to 48 step 8 {
+// CHECK-NEXT:       %[[S0:.*]] = affine.apply #[[$MAP_ID3]](%[[ARG3]], %[[ARG4]], %[[I0]])
+// CHECK-NEXT:       %[[S1:.*]] = affine.apply #[[$MAP_ID4]](%[[ARG3]], %[[ARG4]], %[[I0]])
+// CHECK-NEXT:       %[[S2:.*]] = affine.apply #[[$MAP_ID5]](%[[ARG3]], %[[ARG4]], %[[I0]])
+// CHECK-NEXT:       %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:       %[[S3:.*]] = vector.transfer_read %[[ARG0]][%[[S0]], %[[S1]], %[[S2]]], %[[CST]] {permutation_map = #[[$MAP_ID6]]} : memref<8x12x16xf32>, vector<8xf32>
+// CHECK-NEXT:       vector.transfer_write %[[S3]], %[[ARG1]][%[[ARG3]], %[[ARG4]], %[[ARG5]]] : vector<8xf32>, memref<8x24x48xf32>
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+  affine.for %arg2 = 0 to 8 {
+    affine.for %arg3 = 0 to 12 {
+      affine.for %arg4 = 0 to 48 {
+        %2 = affine.load %arg0[%arg2, %arg3, %i + 1] : memref<8x12x16xf32>
+        affine.store %2, %arg1[%arg2, %arg3, %arg4] : memref<8x24x48xf32>
+      }
+    }
+  }
+  return
+}


        


More information about the Mlir-commits mailing list