[Mlir-commits] [mlir] [MLIR][XeGPU] Add lowering from transfer_read/transfer_write to load_gather/store_scatter (PR #152429)

Adam Siemieniuk llvmlistbot at llvm.org
Mon Aug 11 03:14:01 PDT 2025


================
@@ -155,6 +152,340 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
   return ndDesc;
 }
 
+// Checks the extra preconditions for lowering a vector transfer op to a
+// scattered load/store.
+//
+// Returns a match failure (reported through `rewriter`) when the op has an
+// out-of-bounds dimension, or when its source is a statically shaped memref
+// that fails the rank/shape checks below. Sources that are not a MemRefType
+// and memrefs without a static shape skip the shape checks. Returns success
+// otherwise.
+static LogicalResult
+extraCheckForScatteredLoadStore(VectorTransferOpInterface xferOp,
+                                PatternRewriter &rewriter) {
+  // 1. The access must be in-bounds: any op for which hasOutOfBoundsDim() is
+  // true (presumably e.g. {in_bounds = [false, true]}) is rejected.
+  if (xferOp.hasOutOfBoundsDim())
+    return rewriter.notifyMatchFailure(xferOp,
+                                       "Out-of-bounds access is not supported "
+                                       "for scatter load/store lowering");
+  // 2. If the memref has a static shape, its trailing dimensions are checked
+  // against the vector shape.
+  // NOTE(review): the original comment and the failure message say the
+  // dimensions must "exactly match", but the comparison below rejects a memref
+  // dimension that is less than OR EQUAL to the vector dimension, so equal
+  // sizes are rejected as well. Confirm the intended comparison.
+  if (auto memrefType = dyn_cast<MemRefType>(xferOp.getShapedType())) {
+    if (memrefType.hasStaticShape()) {
+      ArrayRef<int64_t> memrefShape = memrefType.getShape();
+      ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
+      size_t memrefRank = memrefShape.size();
+      size_t vectorRank = vectorShape.size();
+      // Guard first: the index computation below subtracts vectorRank from
+      // memrefRank using unsigned arithmetic.
+      if (vectorRank > memrefRank)
+        return rewriter.notifyMatchFailure(
+            xferOp, "Vector rank cannot exceed memref rank");
+      // Compare the last vectorRank dimensions of memref with vector shape
+      for (size_t i = 0; i < vectorRank; ++i) {
+        if (memrefShape[memrefRank - vectorRank + i] <= vectorShape[i])
+          return rewriter.notifyMatchFailure(
+              xferOp, "Memref lower dimensions must match vector shape");
+      }
+    }
+  }
+  return success();
+}
+
+// Reorders, in place, the trailing `vecRank` entries of `strides` so that
+// entry i holds the stride of the memref dimension selected by result i of
+// `permMap`.
+//
+// Does nothing and returns success when `permMap` is a minor identity.
+// Returns a match failure (reported on `op` through `rewriter`) when a map
+// result is not an AffineDimExpr, or when it refers to a memref dimension
+// before the trailing `vecRank` dimensions.
+//
+// Assumes `strides` holds one entry per memref dimension and that
+// vecRank <= memrefRank (the subtraction below is unsigned) - TODO confirm
+// against callers.
+static LogicalResult adjustStridesForPermutation(
+    Operation *op, PatternRewriter &rewriter, MemRefType memrefType,
+    AffineMap permMap, VectorType vecType, SmallVectorImpl<Value> &strides) {
+  unsigned vecRank;
+  unsigned memrefRank = memrefType.getRank();
+
+  if (permMap.isMinorIdentity())
+    return success();
+  vecRank = vecType.getRank();
+  // Only adjust the last vecRank strides according to the permutation
+  // `relevantStrides` is a non-owning view into `strides`; the permuted values
+  // are staged in `adjustedStrides` so the view is not overwritten while it is
+  // still being read.
+  ArrayRef<Value> relevantStrides = ArrayRef<Value>(strides).take_back(vecRank);
+  SmallVector<Value> adjustedStrides(vecRank);
+  // For each output dimension in the permutation map, find which input dim it
+  // refers to, and assign the corresponding stride.
+  for (unsigned outIdx = 0; outIdx < vecRank; ++outIdx) {
+    AffineExpr expr = permMap.getResult(outIdx);
+    auto dimExpr = dyn_cast<AffineDimExpr>(expr);
+    if (!dimExpr) {
+      return rewriter.notifyMatchFailure(op, "Unsupported permutation expr");
+    }
+    unsigned pos = dimExpr.getPosition();
+    // Map permutation to the relevant strides (innermost dims)
+    // Only the lower bound is checked here; presumably pos < memrefRank is
+    // guaranteed by the map itself - verify.
+    if (pos < memrefRank - vecRank) {
+      return rewriter.notifyMatchFailure(op, "Permutation out of bounds");
+    }
+    // The stride for output dimension outIdx is the stride of input dimension
+    // pos
+    adjustedStrides[outIdx] = relevantStrides[pos - (memrefRank - vecRank)];
+  }
+  // Replace the last vecRank strides with the adjusted ones
+  for (unsigned i = 0; i < vecRank; ++i)
+    strides[memrefRank - vecRank + i] = adjustedStrides[i];
+
+  return success();
+}
+
+// Returns the strides of the transfer op's base memref as index Values, with
+// the trailing entries reordered according to the op's permutation map.
+//
+// For a statically shaped memref the strides come from getStridesAndOffset()
+// and are materialized as arith.constant index ops; otherwise they are the
+// stride results of a memref.extract_strided_metadata op created here.
+//
+// Returns an empty vector when the strides cannot be computed or the
+// permutation adjustment fails. Note that the constant / metadata ops have
+// already been created through `rewriter` by the time the permutation
+// adjustment can fail.
+//
+// The base is cast with llvm::cast, so it is assumed to be a MemRefType -
+// TODO confirm callers guarantee this.
+SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
+                                  PatternRewriter &rewriter) {
+  SmallVector<Value> strides;
+  Value baseMemref = xferOp.getBase();
+  AffineMap permMap = xferOp.getPermutationMap();
+  VectorType vectorType = xferOp.getVectorType();
+  MemRefType memrefType = llvm::cast<MemRefType>(baseMemref.getType());
+
+  Location loc = xferOp.getLoc();
+  if (memrefType.hasStaticShape()) {
+    int64_t offset;
+    SmallVector<int64_t> intStrides;
+    if (failed(memrefType.getStridesAndOffset(intStrides, offset))) {
+      return {};
+    }
+    // Wrap static strides as MLIR values
+    for (int64_t s : intStrides)
+      strides.push_back(rewriter.create<arith::ConstantIndexOp>(loc, s));
+  } else {
+    // For dynamic shape memref, use memref.extract_strided_metadata to get
+    // stride values
+    unsigned rank = memrefType.getRank();
+    Type indexType = rewriter.getIndexType();
+
+    // Result types: [base_memref, offset, then 2 * rank index values for the
+    // strides and sizes].
+    // NOTE(review): the op definition presumably lists sizes before strides;
+    // every entry is index-typed, so the type list built here is the same
+    // either way - confirm.
+    SmallVector<Type> resultTypes;
+    resultTypes.push_back(MemRefType::get(
+        {}, memrefType.getElementType())); // base memref (empty shape: rank 0)
+    resultTypes.push_back(indexType);      // offset
+    for (unsigned i = 0; i < rank; ++i) {
+      resultTypes.push_back(indexType); // strides
+    }
+    for (unsigned i = 0; i < rank; ++i) {
+      resultTypes.push_back(indexType); // sizes
+    }
+
+    auto meta = rewriter.create<memref::ExtractStridedMetadataOp>(
+        loc, resultTypes, baseMemref);
+    strides.append(meta.getStrides().begin(), meta.getStrides().end());
+  }
+  // Adjust strides according to the permutation map (e.g., for transpose)
+  if (failed(adjustStridesForPermutation(xferOp, rewriter, memrefType, permMap,
+                                         vectorType, strides))) {
+    return {};
+  }
+  return strides;
+}
+
+// Computes the vector of element offsets used for scattered loads/stores.
+// It is used in the lowering of vector.transfer_read/write to
+// load_gather/store_scatter.
+//
+// `strides` is indexed both by memref dimension (for the vector dimensions,
+// counted from the back) and by transfer index position, so it is expected to
+// hold one entry per memref dimension - TODO confirm against callers.
+// Returns an index vector with the same shape as the transferred vector.
+//
+// Example:
+//   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
+//               %cst {in_bounds = [true, true, true, true]} :
+//               memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
+//
+//   %6 = vector.step: vector<4xindex>
+//   %7 = vector.step: vector<2xindex>
+//   %8 = vector.step: vector<6xindex>
+//   %9 = vector.step: vector<32xindex>
+//   %10 = arith.muli %6, 384
+//   %11 = arith.muli %7, 192
+//   %12 = arith.muli %8, 32
+//   %13 = arith.muli %9, 1
+//   %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xindex>
+//   %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xindex>
+//   %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xindex>
+//   %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xindex>
+//   %18 = vector.broadcast %14: vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
+//   %19 = vector.broadcast %15: vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
+//   %20 = vector.broadcast %16: vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
+//   %21 = vector.broadcast %17: vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
+//   %22 = arith.addi %18, %19
+//   %23 = arith.addi %20, %21
+//   %local_offsets = arith.addi %22, %23
+//   %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
+//   %offsets =  orig_offset + local_offsets
+//
+// (The example is schematic: the code below broadcasts each scalar stride to
+// a vector before multiplying and sums the broadcasted vectors left to
+// right.)
+static Value computeOffsets(VectorTransferOpInterface xferOp,
+                            PatternRewriter &rewriter,
+                            ArrayRef<Value> strides) {
+  Location loc = xferOp.getLoc();
+  VectorType vectorType = xferOp.getVectorType();
+  SmallVector<Value> indices(xferOp.getIndices().begin(),
+                             xferOp.getIndices().end());
+  ArrayRef<int64_t> vectorShape = vectorType.getShape();
+
+  // Step 1: Create vector.step operations for each dimension
+  SmallVector<Value> stepVectors;
+  for (int64_t dim : vectorShape) {
+    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
+    auto stepOp = rewriter.create<vector::StepOp>(loc, stepType);
+    stepVectors.push_back(stepOp);
+  }
+
+  // Step 2: Multiply step vectors by corresponding strides
+  // The memref rank is taken from the number of strides; vector dimension i
+  // uses the stride at position memrefRank - vectorRank + i. Assumes
+  // strides.size() >= vectorRank (unsigned subtraction) - TODO confirm.
+  size_t memrefRank = strides.size();
+  size_t vectorRank = vectorShape.size();
+  SmallVector<Value> strideMultiplied;
+  for (size_t i = 0; i < vectorRank; ++i) {
+    size_t memrefDim = memrefRank - vectorRank + i;
+    Value strideValue = strides[memrefDim];
+    auto mulType = llvm::cast<VectorType>(stepVectors[i].getType());
+    auto mulOp = rewriter.create<arith::MulIOp>(
+        loc, stepVectors[i],
+        rewriter.create<vector::BroadcastOp>(loc, mulType, strideValue));
+    strideMultiplied.push_back(mulOp);
+  }
+
+  // Step 3: Shape cast each multiplied vector to add singleton dimensions
+  SmallVector<Value> shapeCasted;
+  for (size_t i = 0; i < vectorRank; ++i) {
+    SmallVector<int64_t> newShape(vectorRank, 1);
+    newShape[i] = vectorShape[i];
+    auto newType = VectorType::get(newShape, rewriter.getIndexType());
+    auto castOp =
+        rewriter.create<vector::ShapeCastOp>(loc, newType, strideMultiplied[i]);
+    shapeCasted.push_back(castOp);
+  }
+
+  // Step 4: Broadcast each shape-casted vector to full vector shape
+  SmallVector<Value> broadcasted;
+  auto fullIndexVectorType =
+      VectorType::get(vectorShape, rewriter.getIndexType());
+  for (Value shapeCastVal : shapeCasted) {
+    auto broadcastOp = rewriter.create<vector::BroadcastOp>(
+        loc, fullIndexVectorType, shapeCastVal);
+    broadcasted.push_back(broadcastOp);
+  }
+
+  // Step 5: Add all broadcasted vectors together to compute local offsets
+  // Reads broadcasted[0] unconditionally, so the vector type is assumed to
+  // have rank >= 1 - TODO confirm rank-0 vectors cannot reach this point.
+  Value localOffsets = broadcasted[0];
+  for (size_t i = 1; i < broadcasted.size(); ++i) {
+    localOffsets =
+        rewriter.create<arith::AddIOp>(loc, localOffsets, broadcasted[i]);
+  }
+
+  // Step 6: Compute base offset from transfer read indices
+  // The base offset is the sum of indices[i] * strides[i]; it is skipped
+  // entirely when the op has no indices. Assumes strides has at least one
+  // entry per index - TODO confirm.
+  Value baseOffset = nullptr;
+  if (!indices.empty()) {
+    baseOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    for (size_t i = 0; i < indices.size(); ++i) {
+      Value strideVal = strides[i];
+      Value offsetContrib =
+          rewriter.create<arith::MulIOp>(loc, indices[i], strideVal);
+      baseOffset =
+          rewriter.create<arith::AddIOp>(loc, baseOffset, offsetContrib);
+    }
+    // Broadcast base offset to match vector shape
+    Value bcastBase = rewriter.create<vector::BroadcastOp>(
+        loc, fullIndexVectorType, baseOffset);
+    localOffsets = rewriter.create<arith::AddIOp>(loc, bcastBase, localOffsets);
+  }
+  return localOffsets;
+}
+
+// Collapse memref shape to 1D
+static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
+                                PatternRewriter &rewriter) {
+  Location loc = xferOp.getLoc();
+
+  Value baseMemref = xferOp.getBase();
+  MemRefType memrefType = llvm::cast<MemRefType>(baseMemref.getType());
+  Type elementType = memrefType.getElementType();
+
+  // Compute the total number of elements in the memref
+  int64_t totalElements = 1;
+  bool hasDynamicDim = false;
----------------
adam-smnk wrote:

nit: looks like you could assign `ShapedType::kDynamic` to `totalElements` directly and skip the separate flag

https://github.com/llvm/llvm-project/pull/152429


More information about the Mlir-commits mailing list