[Mlir-commits] [mlir] [mlir][vector][xegpu] Vector to XeGPU conversion pass (PR #107419)
Adam Siemieniuk
llvmlistbot at llvm.org
Thu Sep 5 10:38:39 PDT 2024
================
@@ -0,0 +1,257 @@
+//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of vector operations to XeGPU dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <algorithm>
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Returns true when `val` is the result of an `arith::ConstantOp` whose
+/// attribute is a FloatAttr or IntegerAttr holding zero. Values produced by
+/// any other op, or constants of any other attribute kind, yield false.
+static bool isZeroConstant(Value val) {
+  auto cstOp = val.getDefiningOp<arith::ConstantOp>();
+  if (!cstOp)
+    return false;
+
+  Attribute cstAttr = cstOp.getValue();
+  if (auto fpAttr = dyn_cast<FloatAttr>(cstAttr))
+    return fpAttr.getValue().isZero();
+  if (auto intAttr = dyn_cast<IntegerAttr>(cstAttr))
+    return intAttr.getValue().isZero();
+  return false;
+}
+
+/// Checks the constraints shared by the vector transfer lowerings. The
+/// transfer op must:
+///   - have no mask,
+///   - use a memref (not a tensor) as its shaped operand,
+///   - transfer a 1D or 2D vector,
+///   - have a projected-permutation map with no constant-zero results whose
+///     results reference only the `vecRank` innermost memref dimensions,
+///   - access a buffer whose innermost stride is statically 1.
+/// On violation, reports a match failure through `rewriter` and returns it.
+static LogicalResult transferPreconditions(PatternRewriter &rewriter,
+                                           VectorTransferOpInterface xferOp) {
+  if (xferOp.getMask())
+    return rewriter.notifyMatchFailure(xferOp,
+                                       "Masked transfer is not supported");
+
+  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
+  if (!srcTy)
+    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
+  VectorType vecTy = xferOp.getVectorType();
+  unsigned vecRank = vecTy.getRank();
+  if (!(vecRank == 1 || vecRank == 2))
+    return rewriter.notifyMatchFailure(xferOp, "Expects 1D or 2D vector");
+
+  // Validate the map before inspecting strides: once it is known to reference
+  // `vecRank` (>= 1) memref dimensions, the memref cannot be rank-0 and its
+  // stride list is non-empty. A rank-0 memref can otherwise reach this point
+  // via a broadcasting map, and `strides.back()` would be invalid.
+  AffineMap map = xferOp.getPermutationMap();
+  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
+    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
+  unsigned numInputDims = map.getNumInputs();
+  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
+    // Without zero results a projected permutation contains only dimension
+    // expressions, so this cast cannot fail.
+    auto dim = cast<AffineDimExpr>(expr);
+    if (dim.getPosition() < (numInputDims - vecRank))
+      return rewriter.notifyMatchFailure(
+          xferOp, "Only the innermost dimensions can be accessed");
+  }
+
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
+      strides.back() != 1)
+    return rewriter.notifyMatchFailure(
+        xferOp, "Buffer must be contiguous in the innermost dimension");
+
+  return success();
+}
+
+/// Builds an `xegpu::CreateNdDescOp` of type `descType` over `src`, starting
+/// at `offsets`.
+///
+/// When the memref has both a static shape and static strides, the op is
+/// built directly from the memref and the offsets. Otherwise the source shape
+/// and strides are passed explicitly:
+///   - offsets that fold to constants become static entries, the rest are
+///     dynamic operands;
+///   - dynamic sizes are read with `memref::DimOp`;
+///   - dynamic strides are read from the buffer's layout via
+///     `memref::ExtractStridedMetadataOp`.
+static xegpu::CreateNdDescOp
+createNdDescriptor(PatternRewriter &rewriter, Location loc,
+                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
+                   Operation::operand_range offsets) {
+  MemRefType srcTy = src.getType();
+  // Only the strides are needed here.
+  SmallVector<int64_t> strides = getStridesAndOffset(srcTy).first;
+  bool hasDynamicStrides = llvm::is_contained(strides, ShapedType::kDynamic);
+
+  // The memref-only builder supplies no stride operands, so it is valid only
+  // when both the shape and the strides are fully static.
+  if (srcTy.hasStaticShape() && !hasDynamicStrides)
+    return rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
+                                                  getAsOpFoldResult(offsets));
+
+  // In case of any dynamic shapes or strides, the source's shape and strides
+  // have to be explicitly provided.
+  SmallVector<int64_t> constOffsets;
+  SmallVector<Value> dynOffsets;
+  for (Value idx : offsets) {
+    std::optional<int64_t> staticVal = getConstantIntValue(idx);
+    if (!staticVal)
+      dynOffsets.push_back(idx);
+    constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
+  }
+
+  SmallVector<Value> dynShapes;
+  for (auto [dimIdx, size] : llvm::enumerate(srcTy.getShape())) {
+    if (ShapedType::isDynamic(size))
+      dynShapes.push_back(rewriter.create<memref::DimOp>(loc, src, dimIdx));
+  }
+
+  // Read dynamic strides from the buffer's actual layout. Recomputing them as
+  // products of the inner sizes is only correct for row-major contiguous
+  // buffers and gives wrong values for e.g. strided subviews.
+  SmallVector<Value> dynStrides;
+  if (hasDynamicStrides) {
+    auto metadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+    for (auto [dimIdx, stride] : llvm::enumerate(strides)) {
+      if (ShapedType::isDynamic(stride))
+        dynStrides.push_back(metadata.getStrides()[dimIdx]);
+    }
+  }
+
+  return rewriter.create<xegpu::CreateNdDescOp>(
+      loc, descType, src, dynOffsets, dynShapes, dynStrides,
+      DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
+      DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
+      DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+}
+
+/// Rewrites `vector::TransferReadOp` into an `xegpu::LoadNdOp` that reads
+/// through a descriptor built by `createNdDescriptor`.
+///
+/// Out-of-bounds reads are accepted only when the padding value is a zero
+/// constant; they set `boundary_check` on the descriptor type. A permutation
+/// map that is not a minor identity is lowered as a transposed load, which
+/// requires an integer or float element type of at least 32 bits.
+struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = readOp.getLoc();
+
+    if (failed(transferPreconditions(rewriter, readOp)))
+      return failure();
+
+    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
+    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
+      return rewriter.notifyMatchFailure(
+          readOp, "Unsupported non-zero padded out-of-bounds read");
+
+    AffineMap readMap = readOp.getPermutationMap();
+    bool isTransposeLoad = !readMap.isMinorIdentity();
+
+    VectorType vecTy = readOp.getVectorType();
+    Type elementType = vecTy.getElementType();
+    constexpr unsigned minTransposeBitWidth = 32;
+    // Check the type kind first: getIntOrFloatBitWidth() asserts on element
+    // types that are neither integer nor float (e.g. `index`).
+    if (isTransposeLoad &&
+        (!elementType.isIntOrFloat() ||
+         elementType.getIntOrFloatBitWidth() < minTransposeBitWidth))
+      return rewriter.notifyMatchFailure(
+          readOp, "Unsupported data type for transposition");
+
+    // If load is transposed, get the base shape for the tensor descriptor.
+    SmallVector<int64_t> descShape{vecTy.getShape()};
+    if (isTransposeLoad)
+      std::reverse(descShape.begin(), descShape.end());
+    auto descType = xegpu::TensorDescType::get(
+        descShape, elementType, /*scattered=*/false, /*array_length=*/1,
+        xegpu::MemoryScope::Global,
+        /*boundary_check=*/isOutOfBounds);
+
+    // transferPreconditions() already rejected non-memref sources, so this
+    // cast cannot fail.
+    xegpu::CreateNdDescOp ndDesc =
+        createNdDescriptor(rewriter, loc, descType,
+                           cast<TypedValue<MemRefType>>(readOp.getSource()),
+                           readOp.getIndices());
+
+    DenseI64ArrayAttr transposeAttr =
+        !isTransposeLoad ? nullptr
+                         : DenseI64ArrayAttr::get(rewriter.getContext(),
+                                                  ArrayRef<int64_t>{1, 0});
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
+    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
+        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
+        /*l1_hint=*/hint,
+        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    rewriter.replaceOp(readOp, loadOp);
+
+    return success();
+  }
+};
+
+struct TransferWriteLowering
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = writeOp.getLoc();
+
+ if (failed(transferPreconditions(rewriter, writeOp)))
+ return failure();
+
+ if (writeOp.hasOutOfBoundsDim())
----------------
adam-smnk wrote:
@chencha3 Actually, I'm not sure whether out-of-bounds `store_nd` is supported.
https://github.com/llvm/llvm-project/pull/107419
More information about the Mlir-commits
mailing list