[Mlir-commits] [mlir] 285d321 - [mlir][arith] Define mulsi_extended op
Jakub Kuderski
llvmlistbot at llvm.org
Fri Dec 9 17:26:28 PST 2022
Author: Jakub Kuderski
Date: 2022-12-09T20:25:31-05:00
New Revision: 285d321a855efb9415d0dcfa92141ae725dc7339
URL: https://github.com/llvm/llvm-project/commit/285d321a855efb9415d0dcfa92141ae725dc7339
DIFF: https://github.com/llvm/llvm-project/commit/285d321a855efb9415d0dcfa92141ae725dc7339.diff
LOG: [mlir][arith] Define mulsi_extended op
Extend D139688 with the signed version of the extended multiplication
op. Add conversion to the SPIR-V and LLVM dialects.
This was originally proposed in:
https://discourse.llvm.org/t/rfc-arith-add-extended-multiplication-ops/66869.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D139743
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
mlir/test/Dialect/Arith/canonicalize.mlir
mlir/test/Dialect/Arith/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index f02a2f89a987d..594ba46d62acb 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -268,7 +268,7 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
let hasCanonicalizer = 1;
let extraClassDeclaration = [{
- ::llvm::Optional<::llvm::SmallVector<int64_t, 4>> getShapeForUnroll();
+ Optional<SmallVector<int64_t, 4>> getShapeForUnroll();
}];
}
@@ -291,6 +291,49 @@ def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// MulSIExtendedOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MulSIExtendedOp : Arith_Op<"mulsi_extended", [Pure, Commutative,
+ AllTypesMatch<["lhs", "rhs", "low", "high"]>]> {
+ let summary = [{
+ extended signed integer multiplication operation
+ }];
+
+ let description = [{
+ Performs (2*N)-bit multiplication on sign-extended operands. Returns two
+ N-bit results: the low and the high halves of the product. The low half has
+ the same value as the result of regular multiplication `arith.muli` with
+ the same operands.
+
+ Example:
+
+ ```mlir
+ // Scalar multiplication.
+ %low, %high = arith.mulsi_extended %a, %b : i32
+
+ // Vector element-wise multiplication.
+ %c:2 = arith.mulsi_extended %d, %e : vector<4xi32>
+
+ // Tensor element-wise multiplication.
+ %x:2 = arith.mulsi_extended %y, %z : tensor<4x?xi8>
+ ```
+ }];
+
+ let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
+ let results = (outs SignlessIntegerLike:$low, SignlessIntegerLike:$high);
+
+ let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
+
+ let hasFolder = 1;
+ let hasCanonicalizer = 1;
+
+ let extraClassDeclaration = [{
+ Optional<SmallVector<int64_t, 4>> getShapeForUnroll();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// MulUIExtendedOp
//===----------------------------------------------------------------------===//
@@ -330,7 +373,7 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
let hasCanonicalizer = 1;
let extraClassDeclaration = [{
- ::llvm::Optional<::llvm::SmallVector<int64_t, 4>> getShapeForUnroll();
+ Optional<SmallVector<int64_t, 4>> getShapeForUnroll();
}];
}
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 9483727f68508..b6317053e0a6d 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
+#include <type_traits>
namespace mlir {
#define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
@@ -142,15 +143,20 @@ struct AddUIExtendedOpLowering
ConversionPatternRewriter &rewriter) const override;
};
-struct MulUIExtendedOpLowering
- : public ConvertOpToLLVMPattern<arith::MulUIExtendedOp> {
- using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+template <typename ArithMulOp, bool IsSigned>
+struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> {
+ using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(arith::MulUIExtendedOp op, OpAdaptor adaptor,
+ matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
+using MulSIExtendedOpLowering =
+ MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
+using MulUIExtendedOpLowering =
+ MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
+
struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
@@ -271,11 +277,12 @@ LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
}
//===----------------------------------------------------------------------===//
-// MulUIExtendedOpLowering
+// MulIExtendedOpLowering
//===----------------------------------------------------------------------===//
-LogicalResult MulUIExtendedOpLowering::matchAndRewrite(
- arith::MulUIExtendedOp op, OpAdaptor adaptor,
+template <typename ArithMulOp, bool IsSigned>
+LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
+ ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Type resultType = adaptor.getLhs().getType();
@@ -308,10 +315,9 @@ LogicalResult MulUIExtendedOpLowering::matchAndRewrite(
assert(LLVM::isCompatibleType(wideType) &&
"LLVM dialect should support all signless integer types");
- Value lhsExt =
- rewriter.create<LLVM::ZExtOp>(loc, wideType, adaptor.getLhs());
- Value rhsExt =
- rewriter.create<LLVM::ZExtOp>(loc, wideType, adaptor.getRhs());
+ using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
+ Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
+ Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
// Split the 2*N-bit wide result into two N-bit values.
@@ -467,6 +473,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
MinUIOpLowering,
MulFOpLowering,
MulIOpLowering,
+ MulSIExtendedOpLowering,
MulUIExtendedOpLowering,
NegFOpLowering,
OrIOpLowering,
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 03e877f56578d..ed5d044f82ed4 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -223,13 +223,13 @@ class AddUIExtendedOpPattern final
ConversionPatternRewriter &rewriter) const override;
};
-/// Converts arith.mului_extended to spirv.UMulExtended.
-class MulUIExtendedOpPattern final
- : public OpConversionPattern<arith::MulUIExtendedOp> {
+/// Converts arith.mul*i_extended to spirv.*MulExtended.
+template <typename ArithMulOp, typename SPIRVMulOp>
+class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
public:
- using OpConversionPattern::OpConversionPattern;
+ using OpConversionPattern<ArithMulOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(arith::MulUIExtendedOp op, OpAdaptor adaptor,
+ matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
@@ -955,15 +955,16 @@ LogicalResult AddUIExtendedOpPattern::matchAndRewrite(
}
//===----------------------------------------------------------------------===//
-// MulUIExtendedOpPattern
+// MulIExtendedOpPattern
//===----------------------------------------------------------------------===//
-LogicalResult MulUIExtendedOpPattern::matchAndRewrite(
- arith::MulUIExtendedOp op, OpAdaptor adaptor,
+template <typename ArithMulOp, typename SPIRVMulOp>
+LogicalResult MulIExtendedOpPattern<ArithMulOp, SPIRVMulOp>::matchAndRewrite(
+ ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
- Value result = rewriter.create<spirv::UMulExtendedOp>(loc, adaptor.getLhs(),
- adaptor.getRhs());
+ Value result =
+ rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
llvm::makeArrayRef(0));
@@ -1070,7 +1071,10 @@ void mlir::arith::populateArithToSPIRVPatterns(
TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
CmpIOpBooleanPattern, CmpIOpPattern,
CmpFOpNanNonePattern, CmpFOpPattern,
- AddUIExtendedOpPattern, MulUIExtendedOpPattern, SelectOpPattern,
+ AddUIExtendedOpPattern,
+ MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
+ MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
+ SelectOpPattern,
MinMaxFOpPattern<arith::MaxFOp, spirv::GLFMaxOp>,
MinMaxFOpPattern<arith::MinFOp, spirv::GLFMinOp>,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 7637369c4e552..cf2a7678c6eb7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -111,6 +111,17 @@ def SubISubILHSRHSLHS :
Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y), $x),
(Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y)>;
+//===----------------------------------------------------------------------===//
+// MulSIExtendedOp
+//===----------------------------------------------------------------------===//
+
+// mulsi_extended(x, y) -> [muli(x, y), x], when the `high` result is unused.
+// Since the `high` result is not used, any replacement value will do.
+def MulSIExtendedToMulI :
+ Pattern<(Arith_MulSIExtendedOp:$res $x, $y),
+ [(Arith_MulIOp $x, $y), (replaceWithValue $x)],
+ [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
+
//===----------------------------------------------------------------------===//
// MulUIExtendedOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index fde37d6bd69d4..25a3dd425cf22 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -350,6 +350,52 @@ OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
operands, [](const APInt &a, const APInt &b) { return a * b; });
}
+//===----------------------------------------------------------------------===//
+// MulSIExtendedOp
+//===----------------------------------------------------------------------===//
+
+Optional<SmallVector<int64_t, 4>> arith::MulSIExtendedOp::getShapeForUnroll() {
+ if (auto vt = getType(0).dyn_cast<VectorType>())
+ return llvm::to_vector<4>(vt.getShape());
+ return std::nullopt;
+}
+
+LogicalResult
+arith::MulSIExtendedOp::fold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ // mulsi_extended(x, 0) -> 0, 0
+ if (matchPattern(getRhs(), m_Zero())) {
+ Attribute zero = operands[1];
+ results.push_back(zero);
+ results.push_back(zero);
+ return success();
+ }
+
+ // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
+ if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
+ operands, [](const APInt &a, const APInt &b) { return a * b; })) {
+ // Invoke the constant fold helper again to calculate the 'high' result.
+ Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
+ operands, [](const APInt &a, const APInt &b) {
+ unsigned bitWidth = a.getBitWidth();
+ APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
+ return fullProduct.extractBits(bitWidth, bitWidth);
+ });
+ assert(highAttr && "Unexpected constant-folding failure");
+
+ results.push_back(lowAttr);
+ results.push_back(highAttr);
+ return success();
+ }
+
+ return failure();
+}
+
+void arith::MulSIExtendedOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<MulSIExtendedToMulI>(context);
+}
+
//===----------------------------------------------------------------------===//
// MulUIExtendedOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 8abf8130744a5..637f9daa1a0b3 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -379,6 +379,38 @@ func.func @addui_extended_vector1d(%arg0: vector<3xi16>, %arg1: vector<3xi16>) -
// -----
+// CHECK-LABEL: @mulsi_extended_scalar
+// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i32)
+func.func @mulsi_extended_scalar(%arg0: i32, %arg1: i32) -> (i32, i32) {
+ // CHECK-NEXT: [[LHS:%.+]] = llvm.sext [[ARG0]] : i32 to i64
+ // CHECK-NEXT: [[RHS:%.+]] = llvm.sext [[ARG1]] : i32 to i64
+ // CHECK-NEXT: [[MUL:%.+]] = llvm.mul [[LHS]], [[RHS]] : i64
+ // CHECK-NEXT: [[LOW:%.+]] = llvm.trunc [[MUL]] : i64 to i32
+ // CHECK-NEXT: [[C32:%.+]] = llvm.mlir.constant(32 : i64) : i64
+ // CHECK-NEXT: [[SHL:%.+]] = llvm.lshr [[MUL]], [[C32]] : i64
+ // CHECK-NEXT: [[HIGH:%.+]] = llvm.trunc [[SHL]] : i64 to i32
+ %low, %high = arith.mulsi_extended %arg0, %arg1 : i32
+ // CHECK-NEXT: return [[LOW]], [[HIGH]] : i32, i32
+ return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @mulsi_extended_vector1d
+// CHECK-SAME: ([[ARG0:%.+]]: vector<3xi64>, [[ARG1:%.+]]: vector<3xi64>) -> (vector<3xi64>, vector<3xi64>)
+func.func @mulsi_extended_vector1d(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> (vector<3xi64>, vector<3xi64>) {
+ // CHECK-NEXT: [[LHS:%.+]] = llvm.sext [[ARG0]] : vector<3xi64> to vector<3xi128>
+ // CHECK-NEXT: [[RHS:%.+]] = llvm.sext [[ARG1]] : vector<3xi64> to vector<3xi128>
+ // CHECK-NEXT: [[MUL:%.+]] = llvm.mul [[LHS]], [[RHS]] : vector<3xi128>
+ // CHECK-NEXT: [[LOW:%.+]] = llvm.trunc [[MUL]] : vector<3xi128> to vector<3xi64>
+ // CHECK-NEXT: [[C64:%.+]] = llvm.mlir.constant(dense<64> : vector<3xi128>) : vector<3xi128>
+ // CHECK-NEXT: [[SHL:%.+]] = llvm.lshr [[MUL]], [[C64]] : vector<3xi128>
+ // CHECK-NEXT: [[HIGH:%.+]] = llvm.trunc [[SHL]] : vector<3xi128> to vector<3xi64>
+ %low, %high = arith.mulsi_extended %arg0, %arg1 : vector<3xi64>
+ // CHECK-NEXT: return [[LOW]], [[HIGH]] : vector<3xi64>, vector<3xi64>
+ return %low, %high : vector<3xi64>, vector<3xi64>
+}
+
+// -----
+
// CHECK-LABEL: @mului_extended_scalar
// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i32)
func.func @mului_extended_scalar(%arg0: i32, %arg1: i32) -> (i32, i32) {
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index c2f642e3fb264..3b18295326dab 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -99,6 +99,29 @@ func.func @int32_vector_addui_extended(%lhs: vector<4xi32>, %rhs: vector<4xi32>)
return %sum, %overflow : vector<4xi32>, vector<4xi1>
}
+// Check extended signed integer multiplication conversions.
+// CHECK-LABEL: @int32_scalar_mulsi_extended
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
+func.func @int32_scalar_mulsi_extended(%lhs: i32, %rhs: i32) -> (i32, i32) {
+ // CHECK-NEXT: %[[MUL:.+]] = spirv.SMulExtended %[[LHS]], %[[RHS]] : !spirv.struct<(i32, i32)>
+ // CHECK-DAG: %[[LOW:.+]] = spirv.CompositeExtract %[[MUL]][0 : i32] : !spirv.struct<(i32, i32)>
+ // CHECK-DAG: %[[HIGH:.+]] = spirv.CompositeExtract %[[MUL]][1 : i32] : !spirv.struct<(i32, i32)>
+ // CHECK-NEXT: return %[[LOW]], %[[HIGH]] : i32, i32
+ %low, %high = arith.mulsi_extended %lhs, %rhs: i32
+ return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @int32_vector_mulsi_extended
+// CHECK-SAME: (%[[LHS:.+]]: vector<4xi32>, %[[RHS:.+]]: vector<4xi32>)
+func.func @int32_vector_mulsi_extended(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+ // CHECK-NEXT: %[[MUL:.+]] = spirv.SMulExtended %[[LHS]], %[[RHS]] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
+ // CHECK-DAG: %[[LOW:.+]] = spirv.CompositeExtract %[[MUL]][0 : i32] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
+ // CHECK-DAG: %[[HIGH:.+]] = spirv.CompositeExtract %[[MUL]][1 : i32] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
+ // CHECK-NEXT: return %[[LOW]], %[[HIGH]] : vector<4xi32>, vector<4xi32>
+ %low, %high = arith.mulsi_extended %lhs, %rhs: vector<4xi32>
+ return %low, %high : vector<4xi32>, vector<4xi32>
+}
+
// Check extended unsigned integer multiplication conversions.
// CHECK-LABEL: @int32_scalar_mului_extended
// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1280a2adece89..644af88e298c7 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -734,6 +734,64 @@ func.func @adduiExtendedConstantsSplatVector() -> (vector<4xi32>, vector<4xi1>)
return %sum, %overflow : vector<4xi32>, vector<4xi1>
}
+// CHECK-LABEL: @mulsiExtendedZeroRhs
+// CHECK-NEXT: %[[zero:.+]] = arith.constant 0 : i32
+// CHECK-NEXT: return %[[zero]], %[[zero]]
+func.func @mulsiExtendedZeroRhs(%arg0: i32) -> (i32, i32) {
+ %zero = arith.constant 0 : i32
+ %low, %high = arith.mulsi_extended %arg0, %zero: i32
+ return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @mulsiExtendedZeroRhsSplat
+// CHECK-NEXT: %[[zero:.+]] = arith.constant dense<0> : vector<3xi32>
+// CHECK-NEXT: return %[[zero]], %[[zero]]
+func.func @mulsiExtendedZeroRhsSplat(%arg0: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) {
+ %zero = arith.constant dense<0> : vector<3xi32>
+ %low, %high = arith.mulsi_extended %arg0, %zero: vector<3xi32>
+ return %low, %high : vector<3xi32>, vector<3xi32>
+}
+
+// CHECK-LABEL: @mulsiExtendedZeroLhs
+// CHECK-NEXT: %[[zero:.+]] = arith.constant 0 : i32
+// CHECK-NEXT: return %[[zero]], %[[zero]]
+func.func @mulsiExtendedZeroLhs(%arg0: i32) -> (i32, i32) {
+ %zero = arith.constant 0 : i32
+ %low, %high = arith.mulsi_extended %zero, %arg0: i32
+ return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @mulsiExtendedUnusedHigh
+// CHECK-SAME: (%[[ARG:.+]]: i32) -> i32
+// CHECK-NEXT: %[[RES:.+]] = arith.muli %[[ARG]], %[[ARG]] : i32
+// CHECK-NEXT: return %[[RES]]
+func.func @mulsiExtendedUnusedHigh(%arg0: i32) -> i32 {
+ %low, %high = arith.mulsi_extended %arg0, %arg0: i32
+ return %low : i32
+}
+
+// CHECK-LABEL: @mulsiExtendedScalarConstants
+// CHECK-DAG: %[[c27:.+]] = arith.constant 27 : i8
+// CHECK-DAG: %[[c_n3:.+]] = arith.constant -3 : i8
+// CHECK-NEXT: return %[[c27]], %[[c_n3]]
+func.func @mulsiExtendedScalarConstants() -> (i8, i8) {
+ %c57 = arith.constant 57 : i8
+ %c_n13 = arith.constant -13 : i8
+ %low, %high = arith.mulsi_extended %c57, %c_n13: i8
+ return %low, %high : i8, i8
+}
+
+// CHECK-LABEL: @mulsiExtendedVectorConstants
+// CHECK-DAG: %[[cstLo:.+]] = arith.constant dense<[65, 79, 34]> : vector<3xi8>
+// CHECK-DAG: %[[cstHi:.+]] = arith.constant dense<[0, 14, 0]> : vector<3xi8>
+// CHECK-NEXT: return %[[cstLo]], %[[cstHi]]
+func.func @mulsiExtendedVectorConstants() -> (vector<3xi8>, vector<3xi8>) {
+ %cstA = arith.constant dense<[5, 37, -17]> : vector<3xi8>
+ %cstB = arith.constant dense<[13, 99, -2]> : vector<3xi8>
+ %low, %high = arith.mulsi_extended %cstA, %cstB: vector<3xi8>
+ return %low, %high : vector<3xi8>, vector<3xi8>
+}
+
// CHECK-LABEL: @muluiExtendedZeroRhs
// CHECK-NEXT: %[[zero:.+]] = arith.constant 0 : i32
// CHECK-NEXT: return %[[zero]], %[[zero]]
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index a9c0b557fdf01..25c015bfa3ccc 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -97,6 +97,30 @@ func.func @test_muli_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]
return %0 : vector<[8]xi64>
}
+// CHECK-LABEL: test_mulsi_extended
+func.func @test_mulsi_extended(%arg0 : i32, %arg1 : i32) -> i32 {
+ %low, %high = arith.mulsi_extended %arg0, %arg1 : i32
+ return %high : i32
+}
+
+// CHECK-LABEL: test_mulsi_extended_tensor
+func.func @test_mulsi_extended_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
+ %low, %high = arith.mulsi_extended %arg0, %arg1 : tensor<8x8xi64>
+ return %high : tensor<8x8xi64>
+}
+
+// CHECK-LABEL: test_mulsi_extended_vector
+func.func @test_mulsi_extended_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
+ %0:2 = arith.mulsi_extended %arg0, %arg1 : vector<8xi64>
+ return %0#0 : vector<8xi64>
+}
+
+// CHECK-LABEL: test_mulsi_extended_scalable_vector
+func.func @test_mulsi_extended_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+ %0:2 = arith.mulsi_extended %arg0, %arg1 : vector<[8]xi64>
+ return %0#1 : vector<[8]xi64>
+}
+
// CHECK-LABEL: test_mului_extended
func.func @test_mului_extended(%arg0 : i32, %arg1 : i32) -> i32 {
%low, %high = arith.mului_extended %arg0, %arg1 : i32
More information about the Mlir-commits
mailing list