[Mlir-commits] [mlir] [mlir][amdgpu] Add conversion from arith.scaling_extf to amdgpu (PR #146372)
Tim Gymnich
llvmlistbot at llvm.org
Tue Jul 8 06:01:56 PDT 2025
================
@@ -395,6 +421,249 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
return success();
}
+/// Get the broadcasted / splatted value for a chain of ops.
+static Value getOriginalVectorValue(Value value) {
+ Value current = value;
+ while (Operation *definingOp = current.getDefiningOp()) {
+ bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
+ .Case<vector::ShapeCastOp>([&current](auto op) {
+ current = op.getSource();
+ return true;
+ })
+ .Case<vector::BroadcastOp>([&current](auto op) {
+ current = op.getSource();
+ return false;
+ })
+ .Case<vector::SplatOp>([&current](auto op) {
+ current = op.getInput();
+ return false;
+ })
+ .Default([](Operation *) { return false; });
+
+ if (!skipOp) {
+ break;
+ }
+ }
+ return current;
+}
+
+LogicalResult
+ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
+ PatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ constexpr int64_t opWidth = 2;
+
+ Value in = op.getIn();
+ Value scale = op.getScale();
+ Value out = op.getOut();
+
+ Type f32 = rewriter.getF32Type();
+ Type inType = getElementTypeOrSelf(in);
+ Type scaleType = getElementTypeOrSelf(scale);
+ Type outType = getElementTypeOrSelf(out);
+
+ VectorType outVecType = dyn_cast<VectorType>(out.getType());
+ VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+
+ if (outVecType && outVecType.isScalable())
+ return failure();
+
+ Type scaleF32Type =
+ scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+ if (scaleType.getIntOrFloatBitWidth() < 32)
+ scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+ else if (scaleType.getIntOrFloatBitWidth() > 32)
+ scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+ VectorType extScaleResultType = VectorType::get(opWidth, outType);
+
+ if (!outVecType) {
+ Value inCast =
+ rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+ // TODO: replace this with non-packed ScaledExtOp
+ Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+ loc, extScaleResultType, inCast, scale, 0);
+ scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
+ return success();
+ }
+
+ VectorType inVecType = cast<VectorType>(in.getType());
+ Value origScale = getOriginalVectorValue(scale);
+
+ int64_t scalarShape[1] = {1};
+ ArrayRef<int64_t> inShape = inVecType.getShape();
+ ArrayRef<int64_t> originalScaleShape = {scalarShape};
+ if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+ originalScaleShape = origScaleVecType.getShape();
+
+ SmallVector<int64_t> paddedScaleShape(originalScaleShape);
+ paddedScaleShape.insert(paddedScaleShape.end(),
+ inShape.size() - originalScaleShape.size(), 1);
+
+ auto maybeRatio = computeShapeRatio(inShape, paddedScaleShape);
+ assert(maybeRatio &&
+ "failed to derive block size from broadcast or splat operation");
+
+ SmallVector<int64_t> ratio =
+ maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));
+
+ int64_t blockSize = computeProduct(ratio);
+
+ Value zero = rewriter.create<arith::ConstantOp>(
+ loc, outType, rewriter.getFloatAttr(outType, 0.0));
+ Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
+ for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
+ SmallVector<int64_t> strides(offsets.size(), 1);
+ Value block = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, in, offsets, ratio, strides);
+ VectorType block1DType = VectorType::get(blockSize, inType);
+ Value block1D =
+ rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+ Value uniformScale =
+ rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+
+ VectorType blockResultType = VectorType::get(blockSize, outType);
+ Value blockResult =
+ rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+
+ for (int64_t i = 0, sliceWidth = opWidth - blockSize % opWidth;
+ i < blockSize;
+ i += sliceWidth, sliceWidth = opWidth - blockSize % opWidth) {
----------------
tgymnich wrote:
fixed and added test cases
https://github.com/llvm/llvm-project/pull/146372
More information about the Mlir-commits
mailing list