[Mlir-commits] [mlir] [mlir][amdgpu] Add conversion from arith.scaling_extf to amdgpu (PR #146372)
Tim Gymnich
llvmlistbot at llvm.org
Tue Jul 1 04:03:48 PDT 2025
================
@@ -395,6 +424,242 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
return success();
}
+/// Walks up the def chain of `value` and returns the value it was derived
+/// from.
+///
+/// - A defining `vector::ShapeCastOp` is looked through and the walk continues
+///   from its source.
+/// - A defining `vector::BroadcastOp` or `vector::SplatOp` is looked through
+///   once (its source/input becomes the result) and the walk then stops.
+/// - Any other defining op, or a value with no defining op, ends the walk
+///   with the current value unchanged.
+static Value getOriginalVectorValue(Value value) {
+  Value current = value;
+  while (Operation *definingOp = current.getDefiningOp()) {
+    // Each case may update `current`; the returned flag says whether to keep
+    // walking past the op that was just looked through.
+    bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
+                      .Case<vector::ShapeCastOp>([&current](auto op) {
+                        current = op.getSource();
+                        return true;
+                      })
+                      .Case<vector::BroadcastOp>([&current](auto op) {
+                        current = op.getSource();
+                        return false;
+                      })
+                      .Case<vector::SplatOp>([&current](auto op) {
+                        current = op.getInput();
+                        return false;
+                      })
+                      .Default([](Operation *) { return false; });
+
+    if (!skipOp)
+      break;
+  }
+  return current;
+}
+
+LogicalResult
+ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
+ PatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ constexpr const int64_t opWidth = 2;
+
+ Value in = op.getIn();
+ Value scale = op.getScale();
+ Value out = op.getOut();
+
+ Type f32 = rewriter.getF32Type();
+ Type inType = getElementTypeOrSelf(in);
+ Type scaleType = getElementTypeOrSelf(scale);
+ Type outType = getElementTypeOrSelf(out);
+ VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+ VectorType inVecType = dyn_cast<VectorType>(in.getType());
+ VectorType outVecType = dyn_cast<VectorType>(out.getType());
+
+ if (outVecType && outVecType.isScalable())
+ return failure();
+
+ Type scaleF32Type =
+ scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+ if (scaleType.getIntOrFloatBitWidth() < 32)
+ scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+ else if (scaleType.getIntOrFloatBitWidth() > 32)
+ scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+ VectorType extScaleResultType = VectorType::get(opWidth, outType);
+
+ if (!outVecType) {
+ Value inCast =
+ rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+ // TODO: replace this with non-packed ScaledExtOp
+ Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+ loc, extScaleResultType, inCast, scale, 0);
+ scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
+ return success();
+ }
+
+ Value origScale = getOriginalVectorValue(scale);
+ Type origScaleType = origScale.getType();
+ VectorType origScaleVecType = isa<VectorType>(origScaleType)
+ ? cast<VectorType>(origScaleType)
+ : VectorType::get(1, origScaleType);
+
+ ArrayRef<int64_t> originalScaleShape = origScaleVecType.getShape();
+ ArrayRef<int64_t> inShape = inVecType.getShape();
+
+ SmallVector<int64_t> paddedScaleShape(originalScaleShape);
+ paddedScaleShape.insert(paddedScaleShape.end(),
+ inShape.size() - originalScaleShape.size(), 1);
+
+ auto ratio = computeShapeRatio(inShape, paddedScaleShape);
+ if (!ratio)
+ return failure();
+
+ const int64_t blockSize = computeProduct(*ratio);
+
+ Value zero = rewriter.create<arith::ConstantOp>(
+ loc, outType, rewriter.getFloatAttr(outType, 0.0));
+ Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
+ for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, *ratio)) {
----------------
tgymnich wrote:
this vector is created in-place on the stack.
https://github.com/llvm/llvm-project/pull/146372
More information about the Mlir-commits
mailing list