[Mlir-commits] [mlir] [mlir][tosa] Make TOSA MUL's Shift an Input (PR #121953)
Georgios Pinitas
llvmlistbot at llvm.org
Mon Jan 13 11:07:20 PST 2025
================
@@ -90,43 +90,59 @@ static Value createLinalgBodyCalculationForElementwiseOp(
}
// tosa::MulOp
- if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
- return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
-
- if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
- Value a = args[0];
- Value b = args[1];
- auto shift =
- cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
- if (shift > 0) {
- auto shiftConst =
- rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
- if (!a.getType().isInteger(32))
- a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
-
- if (!b.getType().isInteger(32))
- b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
-
- auto result = rewriter.create<tosa::ApplyScaleOp>(
- loc, rewriter.getI32Type(), a, b, shiftConst,
- rewriter.getBoolAttr(false));
-
- if (elementTy.isInteger(32))
- return result;
-
- return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
+ if (isa<tosa::MulOp>(op)) {
+ auto shift_val = cast<tosa::MulOp>(op).getShift();
+ if (!elementTy.isInteger(32) && shift_val.getImpl()) {
+ (void)rewriter.notifyMatchFailure(
+ op, "Cannot have shift value for non i32 output");
+ return nullptr;
+ };
+
+ if (isa<FloatType>(elementTy)) {
+ return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
}
- int aWidth = a.getType().getIntOrFloatBitWidth();
- int bWidth = b.getType().getIntOrFloatBitWidth();
- int cWidth = resultTypes[0].getIntOrFloatBitWidth();
+ if (isa<IntegerType>(elementTy)) {
+ int32_t shift = 0;
+ ElementsAttr shift_elem;
+ if (shift_val.getImpl() &&
+ matchPattern(shift_val, m_Constant(&shift_elem))) {
+ // Explicit shift is set.
+ shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+ }
+
+ Value a = args[0];
+ Value b = args[1];
+ if (shift > 0) {
----------------
GeorgeARM wrote:
What will happen here if the shift operand is not a constant?
https://github.com/llvm/llvm-project/pull/121953
More information about the Mlir-commits
mailing list