[Mlir-commits] [mlir] 285d321 - [mlir][arith] Define mulsi_extended op

Jakub Kuderski llvmlistbot at llvm.org
Fri Dec 9 17:26:28 PST 2022


Author: Jakub Kuderski
Date: 2022-12-09T20:25:31-05:00
New Revision: 285d321a855efb9415d0dcfa92141ae725dc7339

URL: https://github.com/llvm/llvm-project/commit/285d321a855efb9415d0dcfa92141ae725dc7339
DIFF: https://github.com/llvm/llvm-project/commit/285d321a855efb9415d0dcfa92141ae725dc7339.diff

LOG: [mlir][arith] Define mulsi_extended op

Extend D139688 with the signed version of the extended multiplication
op. Add conversion to the SPIR-V and LLVM dialects.

This was originally proposed in:
https://discourse.llvm.org/t/rfc-arith-add-extended-multiplication-ops/66869.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D139743

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
    mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
    mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
    mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
    mlir/test/Dialect/Arith/canonicalize.mlir
    mlir/test/Dialect/Arith/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index f02a2f89a987d..594ba46d62acb 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -268,7 +268,7 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
   let hasCanonicalizer = 1;
 
   let extraClassDeclaration = [{
-    ::llvm::Optional<::llvm::SmallVector<int64_t, 4>> getShapeForUnroll();
+    Optional<SmallVector<int64_t, 4>> getShapeForUnroll();
   }];
 }
 
@@ -291,6 +291,49 @@ def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// MulSIExtendedOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MulSIExtendedOp : Arith_Op<"mulsi_extended", [Pure, Commutative,
+    AllTypesMatch<["lhs", "rhs", "low", "high"]>]> {
+  let summary = [{
+    extended signed integer multiplication operation
+  }];
+
+  let description = [{
+    Performs (2*N)-bit multiplication on sign-extended operands. Returns two
+    N-bit results: the low and the high halves of the product. The low half has
+    the same value as the result of regular multiplication `arith.muli` with
+    the same operands.
+
+    Example:
+
+    ```mlir
+    // Scalar multiplication.
+    %low, %high = arith.mulsi_extended %a, %b : i32
+
+    // Vector element-wise multiplication.
+    %c:2 = arith.mulsi_extended %d, %e : vector<4xi32>
+
+    // Tensor element-wise multiplication.
+    %x:2 = arith.mulsi_extended %y, %z : tensor<4x?xi8>
+    ```
+  }];
+
+  let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
+  let results = (outs SignlessIntegerLike:$low, SignlessIntegerLike:$high);
+
+  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
+
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+
+  let extraClassDeclaration = [{
+    Optional<SmallVector<int64_t, 4>> getShapeForUnroll();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // MulUIExtendedOp
 //===----------------------------------------------------------------------===//
@@ -330,7 +373,7 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
   let hasCanonicalizer = 1;
 
   let extraClassDeclaration = [{
-    ::llvm::Optional<::llvm::SmallVector<int64_t, 4>> getShapeForUnroll();
+    Optional<SmallVector<int64_t, 4>> getShapeForUnroll();
   }];
 }
 

diff  --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 9483727f68508..b6317053e0a6d 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
+#include <type_traits>
 
 namespace mlir {
 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
@@ -142,15 +143,20 @@ struct AddUIExtendedOpLowering
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-struct MulUIExtendedOpLowering
-    : public ConvertOpToLLVMPattern<arith::MulUIExtendedOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+template <typename ArithMulOp, bool IsSigned>
+struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> {
+  using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(arith::MulUIExtendedOp op, OpAdaptor adaptor,
+  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+using MulSIExtendedOpLowering =
+    MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
+using MulUIExtendedOpLowering =
+    MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
+
 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
@@ -271,11 +277,12 @@ LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
 }
 
 //===----------------------------------------------------------------------===//
-// MulUIExtendedOpLowering
+// MulIExtendedOpLowering
 //===----------------------------------------------------------------------===//
 
-LogicalResult MulUIExtendedOpLowering::matchAndRewrite(
-    arith::MulUIExtendedOp op, OpAdaptor adaptor,
+template <typename ArithMulOp, bool IsSigned>
+LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
+    ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Type resultType = adaptor.getLhs().getType();
 
@@ -308,10 +315,9 @@ LogicalResult MulUIExtendedOpLowering::matchAndRewrite(
     assert(LLVM::isCompatibleType(wideType) &&
            "LLVM dialect should support all signless integer types");
 
-    Value lhsExt =
-        rewriter.create<LLVM::ZExtOp>(loc, wideType, adaptor.getLhs());
-    Value rhsExt =
-        rewriter.create<LLVM::ZExtOp>(loc, wideType, adaptor.getRhs());
+    using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
+    Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
+    Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
     Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
 
     // Split the 2*N-bit wide result into two N-bit values.
@@ -467,6 +473,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     MinUIOpLowering,
     MulFOpLowering,
     MulIOpLowering,
+    MulSIExtendedOpLowering,
     MulUIExtendedOpLowering,
     NegFOpLowering,
     OrIOpLowering,

diff  --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 03e877f56578d..ed5d044f82ed4 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -223,13 +223,13 @@ class AddUIExtendedOpPattern final
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Converts arith.mului_extended to spirv.UMulExtended.
-class MulUIExtendedOpPattern final
-    : public OpConversionPattern<arith::MulUIExtendedOp> {
+/// Converts arith.mul*i_extended to spirv.*MulExtended.
+template <typename ArithMulOp, typename SPIRVMulOp>
+class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using OpConversionPattern<ArithMulOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(arith::MulUIExtendedOp op, OpAdaptor adaptor,
+  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -955,15 +955,16 @@ LogicalResult AddUIExtendedOpPattern::matchAndRewrite(
 }
 
 //===----------------------------------------------------------------------===//
-// MulUIExtendedOpPattern
+// MulIExtendedOpPattern
 //===----------------------------------------------------------------------===//
 
-LogicalResult MulUIExtendedOpPattern::matchAndRewrite(
-    arith::MulUIExtendedOp op, OpAdaptor adaptor,
+template <typename ArithMulOp, typename SPIRVMulOp>
+LogicalResult MulIExtendedOpPattern<ArithMulOp, SPIRVMulOp>::matchAndRewrite(
+    ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Location loc = op->getLoc();
-  Value result = rewriter.create<spirv::UMulExtendedOp>(loc, adaptor.getLhs(),
-                                                        adaptor.getRhs());
+  Value result =
+      rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
 
   Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
                                                          llvm::makeArrayRef(0));
@@ -1070,7 +1071,10 @@ void mlir::arith::populateArithToSPIRVPatterns(
     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,
     CmpFOpNanNonePattern, CmpFOpPattern,
-    AddUIExtendedOpPattern, MulUIExtendedOpPattern, SelectOpPattern,
+    AddUIExtendedOpPattern,
+    MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
+    MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
+    SelectOpPattern,
 
     MinMaxFOpPattern<arith::MaxFOp, spirv::GLFMaxOp>,
     MinMaxFOpPattern<arith::MinFOp, spirv::GLFMinOp>,

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 7637369c4e552..cf2a7678c6eb7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -111,6 +111,17 @@ def SubISubILHSRHSLHS :
     Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y), $x),
         (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y)>;
 
+//===----------------------------------------------------------------------===//
+// MulSIExtendedOp
+//===----------------------------------------------------------------------===//
+
+// mulsi_extended(x, y) -> [muli(x, y), x], when the `high` result is unused.
+// Since the `high` result is not used, any replacement value will do.
+def MulSIExtendedToMulI :
+    Pattern<(Arith_MulSIExtendedOp:$res $x, $y),
+        [(Arith_MulIOp $x, $y), (replaceWithValue $x)],
+        [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
+
 //===----------------------------------------------------------------------===//
 // MulUIExtendedOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index fde37d6bd69d4..25a3dd425cf22 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -350,6 +350,52 @@ OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
       operands, [](const APInt &a, const APInt &b) { return a * b; });
 }
 
+//===----------------------------------------------------------------------===//
+// MulSIExtendedOp
+//===----------------------------------------------------------------------===//
+
+Optional<SmallVector<int64_t, 4>> arith::MulSIExtendedOp::getShapeForUnroll() {
+  if (auto vt = getType(0).dyn_cast<VectorType>())
+    return llvm::to_vector<4>(vt.getShape());
+  return std::nullopt;
+}
+
+LogicalResult
+arith::MulSIExtendedOp::fold(ArrayRef<Attribute> operands,
+                             SmallVectorImpl<OpFoldResult> &results) {
+  // mulsi_extended(x, 0) -> 0, 0
+  if (matchPattern(getRhs(), m_Zero())) {
+    Attribute zero = operands[1];
+    results.push_back(zero);
+    results.push_back(zero);
+    return success();
+  }
+
+  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
+  if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
+          operands, [](const APInt &a, const APInt &b) { return a * b; })) {
+    // Invoke the constant fold helper again to calculate the 'high' result.
+    Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
+        operands, [](const APInt &a, const APInt &b) {
+          unsigned bitWidth = a.getBitWidth();
+          APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
+          return fullProduct.extractBits(bitWidth, bitWidth);
+        });
+    assert(highAttr && "Unexpected constant-folding failure");
+
+    results.push_back(lowAttr);
+    results.push_back(highAttr);
+    return success();
+  }
+
+  return failure();
+}
+
+void arith::MulSIExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<MulSIExtendedToMulI>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MulUIExtendedOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 8abf8130744a5..637f9daa1a0b3 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -379,6 +379,38 @@ func.func @addui_extended_vector1d(%arg0: vector<3xi16>, %arg1: vector<3xi16>) -
 
 // -----
 
+// CHECK-LABEL: @mulsi_extended_scalar
+// CHECK-SAME:    ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i32)
+func.func @mulsi_extended_scalar(%arg0: i32, %arg1: i32) -> (i32, i32) {
+  // CHECK-NEXT: [[LHS:%.+]]  = llvm.sext [[ARG0]] : i32 to i64
+  // CHECK-NEXT: [[RHS:%.+]]  = llvm.sext [[ARG1]] : i32 to i64
+  // CHECK-NEXT: [[MUL:%.+]]  = llvm.mul [[LHS]], [[RHS]] : i64
+  // CHECK-NEXT: [[LOW:%.+]]  = llvm.trunc [[MUL]] : i64 to i32
+  // CHECK-NEXT: [[C32:%.+]]  = llvm.mlir.constant(32 : i64) : i64
+  // CHECK-NEXT: [[SHL:%.+]]  = llvm.lshr [[MUL]], [[C32]] : i64
+  // CHECK-NEXT: [[HIGH:%.+]] = llvm.trunc [[SHL]] : i64 to i32
+  %low, %high = arith.mulsi_extended %arg0, %arg1 : i32
+  // CHECK-NEXT: return [[LOW]], [[HIGH]] : i32, i32
+  return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @mulsi_extended_vector1d
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<3xi64>, [[ARG1:%.+]]: vector<3xi64>) -> (vector<3xi64>, vector<3xi64>)
+func.func @mulsi_extended_vector1d(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> (vector<3xi64>, vector<3xi64>) {
+  // CHECK-NEXT: [[LHS:%.+]]  = llvm.sext [[ARG0]] : vector<3xi64> to vector<3xi128>
+  // CHECK-NEXT: [[RHS:%.+]]  = llvm.sext [[ARG1]] : vector<3xi64> to vector<3xi128>
+  // CHECK-NEXT: [[MUL:%.+]]  = llvm.mul [[LHS]], [[RHS]] : vector<3xi128>
+  // CHECK-NEXT: [[LOW:%.+]]  = llvm.trunc [[MUL]] : vector<3xi128> to vector<3xi64>
+  // CHECK-NEXT: [[C64:%.+]]  = llvm.mlir.constant(dense<64> : vector<3xi128>) : vector<3xi128>
+  // CHECK-NEXT: [[SHL:%.+]]  = llvm.lshr [[MUL]], [[C64]] : vector<3xi128>
+  // CHECK-NEXT: [[HIGH:%.+]] = llvm.trunc [[SHL]] : vector<3xi128> to vector<3xi64>
+  %low, %high = arith.mulsi_extended %arg0, %arg1 : vector<3xi64>
+  // CHECK-NEXT: return [[LOW]], [[HIGH]] : vector<3xi64>, vector<3xi64>
+  return %low, %high : vector<3xi64>, vector<3xi64>
+}
+
+// -----
+
 // CHECK-LABEL: @mului_extended_scalar
 // CHECK-SAME:    ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i32)
 func.func @mului_extended_scalar(%arg0: i32, %arg1: i32) -> (i32, i32) {

diff  --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index c2f642e3fb264..3b18295326dab 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -99,6 +99,29 @@ func.func @int32_vector_addui_extended(%lhs: vector<4xi32>, %rhs: vector<4xi32>)
   return %sum, %overflow : vector<4xi32>, vector<4xi1>
 }
 
+// Check extended signed integer multiplication conversions.
+// CHECK-LABEL: @int32_scalar_mulsi_extended
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func.func @int32_scalar_mulsi_extended(%lhs: i32, %rhs: i32) -> (i32, i32) {
+  // CHECK-NEXT: %[[MUL:.+]]   = spirv.SMulExtended %[[LHS]], %[[RHS]] : !spirv.struct<(i32, i32)>
+  // CHECK-DAG:  %[[LOW:.+]]   = spirv.CompositeExtract %[[MUL]][0 : i32] : !spirv.struct<(i32, i32)>
+  // CHECK-DAG:  %[[HIGH:.+]]  = spirv.CompositeExtract %[[MUL]][1 : i32] : !spirv.struct<(i32, i32)>
+  // CHECK-NEXT: return %[[LOW]], %[[HIGH]] : i32, i32
+  %low, %high = arith.mulsi_extended %lhs, %rhs: i32
+  return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @int32_vector_mulsi_extended
+// CHECK-SAME: (%[[LHS:.+]]: vector<4xi32>, %[[RHS:.+]]: vector<4xi32>)
+func.func @int32_vector_mulsi_extended(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+  // CHECK-NEXT: %[[MUL:.+]]   = spirv.SMulExtended %[[LHS]], %[[RHS]] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG:  %[[LOW:.+]]   = spirv.CompositeExtract %[[MUL]][0 : i32] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG:  %[[HIGH:.+]]  = spirv.CompositeExtract %[[MUL]][1 : i32] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-NEXT: return %[[LOW]], %[[HIGH]] : vector<4xi32>, vector<4xi32>
+  %low, %high = arith.mulsi_extended %lhs, %rhs: vector<4xi32>
+  return %low, %high : vector<4xi32>, vector<4xi32>
+}
+
 // Check extended unsigned integer multiplication conversions.
 // CHECK-LABEL: @int32_scalar_mului_extended
 // CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1280a2adece89..644af88e298c7 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -734,6 +734,64 @@ func.func @adduiExtendedConstantsSplatVector() -> (vector<4xi32>, vector<4xi1>)
   return %sum, %overflow : vector<4xi32>, vector<4xi1>
 }
 
+// CHECK-LABEL: @mulsiExtendedZeroRhs
+//  CHECK-NEXT:   %[[zero:.+]] = arith.constant 0 : i32
+//  CHECK-NEXT:   return %[[zero]], %[[zero]]
+func.func @mulsiExtendedZeroRhs(%arg0: i32) -> (i32, i32) {
+  %zero = arith.constant 0 : i32
+  %low, %high = arith.mulsi_extended %arg0, %zero: i32
+  return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @mulsiExtendedZeroRhsSplat
+//  CHECK-NEXT:   %[[zero:.+]] = arith.constant dense<0> : vector<3xi32>
+//  CHECK-NEXT:   return %[[zero]], %[[zero]]
+func.func @mulsiExtendedZeroRhsSplat(%arg0: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) {
+  %zero = arith.constant dense<0> : vector<3xi32>
+  %low, %high = arith.mulsi_extended %arg0, %zero: vector<3xi32>
+  return %low, %high : vector<3xi32>, vector<3xi32>
+}
+
+// CHECK-LABEL: @mulsiExtendedZeroLhs
+//  CHECK-NEXT:   %[[zero:.+]] = arith.constant 0 : i32
+//  CHECK-NEXT:   return %[[zero]], %[[zero]]
+func.func @mulsiExtendedZeroLhs(%arg0: i32) -> (i32, i32) {
+  %zero = arith.constant 0 : i32
+  %low, %high = arith.mulsi_extended %zero, %arg0: i32
+  return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @mulsiExtendedUnusedHigh
+//  CHECK-SAME:   (%[[ARG:.+]]: i32) -> i32
+//  CHECK-NEXT:   %[[RES:.+]] = arith.muli %[[ARG]], %[[ARG]] : i32
+//  CHECK-NEXT:   return %[[RES]]
+func.func @mulsiExtendedUnusedHigh(%arg0: i32) -> i32 {
+  %low, %high = arith.mulsi_extended %arg0, %arg0: i32
+  return %low : i32
+}
+
+// CHECK-LABEL: @mulsiExtendedScalarConstants
+//  CHECK-DAG:    %[[c27:.+]] = arith.constant 27 : i8
+//  CHECK-DAG:    %[[c_n3:.+]] = arith.constant -3 : i8
+//  CHECK-NEXT:   return %[[c27]], %[[c_n3]]
+func.func @mulsiExtendedScalarConstants() -> (i8, i8) {
+  %c57 = arith.constant 57 : i8
+  %c_n13 = arith.constant -13 : i8
+  %low, %high = arith.mulsi_extended %c57, %c_n13: i8
+  return %low, %high : i8, i8
+}
+
+// CHECK-LABEL: @mulsiExtendedVectorConstants
+//  CHECK-DAG:    %[[cstLo:.+]] = arith.constant dense<[65, 79, 34]> : vector<3xi8>
+//  CHECK-DAG:    %[[cstHi:.+]] = arith.constant dense<[0, 14, 0]> : vector<3xi8>
+//  CHECK-NEXT:   return %[[cstLo]], %[[cstHi]]
+func.func @mulsiExtendedVectorConstants() -> (vector<3xi8>, vector<3xi8>) {
+  %cstA = arith.constant dense<[5, 37, -17]> : vector<3xi8>
+  %cstB = arith.constant dense<[13, 99, -2]> : vector<3xi8>
+  %low, %high = arith.mulsi_extended %cstA, %cstB: vector<3xi8>
+  return %low, %high : vector<3xi8>, vector<3xi8>
+}
+
 // CHECK-LABEL: @muluiExtendedZeroRhs
 //  CHECK-NEXT:   %[[zero:.+]] = arith.constant 0 : i32
 //  CHECK-NEXT:   return %[[zero]], %[[zero]]

diff  --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index a9c0b557fdf01..25c015bfa3ccc 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -97,6 +97,30 @@ func.func @test_muli_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]
   return %0 : vector<[8]xi64>
 }
 
+// CHECK-LABEL: test_mulsi_extended
+func.func @test_mulsi_extended(%arg0 : i32, %arg1 : i32) -> i32 {
+  %low, %high = arith.mulsi_extended %arg0, %arg1 : i32
+  return %high : i32
+}
+
+// CHECK-LABEL: test_mulsi_extended_tensor
+func.func @test_mulsi_extended_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
+  %low, %high = arith.mulsi_extended %arg0, %arg1 : tensor<8x8xi64>
+  return %high : tensor<8x8xi64>
+}
+
+// CHECK-LABEL: test_mulsi_extended_vector
+func.func @test_mulsi_extended_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
+  %0:2 = arith.mulsi_extended %arg0, %arg1 : vector<8xi64>
+  return %0#0 : vector<8xi64>
+}
+
+// CHECK-LABEL: test_mulsi_extended_scalable_vector
+func.func @test_mulsi_extended_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %0:2 = arith.mulsi_extended %arg0, %arg1 : vector<[8]xi64>
+  return %0#1 : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_mului_extended
 func.func @test_mului_extended(%arg0 : i32, %arg1 : i32) -> i32 {
   %low, %high = arith.mului_extended %arg0, %arg1 : i32


        


More information about the Mlir-commits mailing list