[Mlir-commits] [mlir] f39b472 - [mlir][arith][tosa] Use extended mul in 32-bit `tosa.apply_scale`

Jakub Kuderski llvmlistbot at llvm.org
Mon Dec 12 11:41:50 PST 2022


Author: Jakub Kuderski
Date: 2022-12-12T14:39:58-05:00
New Revision: f39b47264edda30b90ea9cc435fc1434a2e06edc

URL: https://github.com/llvm/llvm-project/commit/f39b47264edda30b90ea9cc435fc1434a2e06edc
DIFF: https://github.com/llvm/llvm-project/commit/f39b47264edda30b90ea9cc435fc1434a2e06edc.diff

LOG: [mlir][arith][tosa] Use extended mul in 32-bit `tosa.apply_scale`

This avoids introducing 64-bit types, which may be difficult to handle on
some targets.

Reviewed By: rsuderman, antiagainst

Differential Revision: https://reviews.llvm.org/D139777

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
    mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index f188b1bbe5dcc..fb0cf4f38d79e 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -127,7 +127,6 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
 
     Type resultTy = op.getType();
     Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
-    Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
 
     Value value = op.getValue();
     if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) {
@@ -144,20 +143,13 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
     Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
     Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
     Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
-    Value thirtyTwo64 = getConstantValue(loc, i64Ty, 32, rewriter);
 
     // Compute the multiplication in 64-bits then select the high / low parts.
-    Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value32);
-    Value multiplier64 =
-        rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
-    Value multiply64 =
-        rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
-
     // Grab out the high/low of the computation
-    Value high64 =
-        rewriter.create<arith::ShRUIOp>(loc, multiply64, thirtyTwo64);
-    Value high32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, high64);
-    Value low32 = rewriter.create<arith::MulIOp>(loc, value32, multiplier32);
+    auto value64 =
+        rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
+    Value low32 = value64.getLow();
+    Value high32 = value64.getHigh();
 
     // Determine the direction and amount to shift the high bits.
     Value shiftOver32 = rewriter.create<arith::CmpIOp>(

diff  --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
index 17a7ac3e76fd0..7f99e38d74199 100644
--- a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
@@ -21,15 +21,9 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
   // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i32
   // CHECK-DAG: %[[C30:.+]] = arith.constant 30 : i32
   // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i32
-  // CHECK-DAG: %[[C32L:.+]] = arith.constant 32 : i64
 
   // Compute the high-low values of the matmul in 64-bits.
-  // CHECK-DAG: %[[V64:.+]] = arith.extsi %arg0 : i32 to i64
-  // CHECK-DAG: %[[M64:.+]] = arith.extsi %arg1 : i32 to i64
-  // CHECK-DAG: %[[MUL64:.+]] = arith.muli %[[V64]], %[[M64]]
-  // CHECK-DAG: %[[HI64:.+]] = arith.shrui %[[MUL64]], %[[C32L]]
-  // CHECK-DAG: %[[HI:.+]] = arith.trunci %[[HI64]] : i64 to i32
-  // CHECK-DAG: %[[LOW:.+]] = arith.muli %arg0, %arg1
+  // CHECK-DAG: %[[LOW:.+]], %[[HI:.+]] = arith.mulsi_extended %arg0, %arg1
 
   // Determine whether the high bits need to shift left or right and by how much.
   // CHECK-DAG: %[[OVER31:.+]] = arith.cmpi sge, %[[S32]], %[[C32]]


        


More information about the Mlir-commits mailing list