[Mlir-commits] [mlir] [mlir][tosa] Make TOSA MUL's Shift an Input (PR #121953)

Georgios Pinitas llvmlistbot at llvm.org
Mon Jan 27 11:11:43 PST 2025


================
@@ -90,43 +90,59 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   }
 
   // tosa::MulOp
-  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
-    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
-
-  if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
-    Value a = args[0];
-    Value b = args[1];
-    auto shift =
-        cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
-    if (shift > 0) {
-      auto shiftConst =
-          rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
-      if (!a.getType().isInteger(32))
-        a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
-
-      if (!b.getType().isInteger(32))
-        b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
-
-      auto result = rewriter.create<tosa::ApplyScaleOp>(
-          loc, rewriter.getI32Type(), a, b, shiftConst,
-          rewriter.getBoolAttr(false));
-
-      if (elementTy.isInteger(32))
-        return result;
-
-      return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
+  if (isa<tosa::MulOp>(op)) {
+    auto shift_val = cast<tosa::MulOp>(op).getShift();
+    if (!elementTy.isInteger(32) && shift_val.getImpl()) {
+      (void)rewriter.notifyMatchFailure(
+          op, "Cannot have shift value for non i32 output");
+      return nullptr;
+    };
+
+    if (isa<FloatType>(elementTy)) {
+      return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
     }
 
-    int aWidth = a.getType().getIntOrFloatBitWidth();
-    int bWidth = b.getType().getIntOrFloatBitWidth();
-    int cWidth = resultTypes[0].getIntOrFloatBitWidth();
+    if (isa<IntegerType>(elementTy)) {
+      int32_t shift = 0;
+      ElementsAttr shift_elem;
+      if (shift_val.getImpl() &&
+          matchPattern(shift_val, m_Constant(&shift_elem))) {
+        // Explicit shift is set.
+        shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+      }
+
+      Value a = args[0];
+      Value b = args[1];
+      if (shift > 0) {
----------------
GeorgeARM wrote:

Ack.

https://github.com/llvm/llvm-project/pull/121953


More information about the Mlir-commits mailing list