[Mlir-commits] [mlir] b4bdcea - [mlir][arith] Define mului_extended op

Jakub Kuderski llvmlistbot at llvm.org
Fri Dec 9 14:38:37 PST 2022


Author: Jakub Kuderski
Date: 2022-12-09T17:37:06-05:00
New Revision: b4bdcea2148f8885ac4a109e1cd3c9e67a46c161

URL: https://github.com/llvm/llvm-project/commit/b4bdcea2148f8885ac4a109e1cd3c9e67a46c161
DIFF: https://github.com/llvm/llvm-project/commit/b4bdcea2148f8885ac4a109e1cd3c9e67a46c161.diff

LOG: [mlir][arith] Define mului_extended op

Add conversion to the SPIR-V and LLVM dialects.

This was originally proposed in:
https://discourse.llvm.org/t/rfc-arith-add-extended-multiplication-ops/66869.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D139688

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
    mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
    mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
    mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
    mlir/test/Dialect/Arith/canonicalize.mlir
    mlir/test/Dialect/Arith/ops.mlir
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 0c54acac59a36..f02a2f89a987d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -291,6 +291,49 @@ def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// MulUIExtendedOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
+    AllTypesMatch<["lhs", "rhs", "low", "high"]>]> {
+  let summary = [{
+    extended unsigned integer multiplication operation
+  }];
+
+  let description = [{
+    Performs (2*N)-bit multiplication on zero-extended operands. Returns two
+    N-bit results: the low and the high halves of the product. The low half has
+    the same value as the result of regular multiplication `arith.muli` with
+    the same operands.
+
+    Example:
+
+    ```mlir
+    // Scalar multiplication.
+    %low, %high = arith.mului_extended %a, %b : i32
+
+    // Vector element-wise multiplication.
+    %c:2 = arith.mului_extended %d, %e : vector<4xi32>
+
+    // Tensor element-wise multiplication.
+    %x:2 = arith.mului_extended %y, %z : tensor<4x?xi8>
+    ```
+  }];
+
+  let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
+  let results = (outs SignlessIntegerLike:$low, SignlessIntegerLike:$high);
+  // Every value has the same type (AllTypesMatch), so print only type($lhs).
+  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
+
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+  // Vector shape of the results, or std::nullopt when they are not vectors.
+  let extraClassDeclaration = [{
+    ::llvm::Optional<::llvm::SmallVector<int64_t, 4>> getShapeForUnroll();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // DivUIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 0289bea88b504..9483727f68508 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -142,6 +142,15 @@ struct AddUIExtendedOpLowering
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+struct MulUIExtendedOpLowering
+    : ConvertOpToLLVMPattern<arith::MulUIExtendedOp> {
+  using ConvertOpToLLVMPattern<arith::MulUIExtendedOp>::ConvertOpToLLVMPattern;
+  /// Emulates the op with zext + llvm.mul at twice the bit width.
+  LogicalResult
+  matchAndRewrite(arith::MulUIExtendedOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
@@ -261,6 +270,67 @@ LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
                                      "ND vector types are not supported yet");
 }
 
+//===----------------------------------------------------------------------===//
+// MulUIExtendedOpLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult MulUIExtendedOpLowering::matchAndRewrite(
+    arith::MulUIExtendedOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Type resultType = adaptor.getLhs().getType();
+
+  if (!LLVM::isCompatibleType(resultType))
+    return failure();
+
+  Location loc = op.getLoc();
+
+  // Handle the scalar and 1D vector cases. Because LLVM does not have a
+  // matching extended multiplication intrinsic, perform regular multiplication
+  // on operands zero-extended to i(2*N) bits, and truncate the results back to
+  // iN types.
+  if (!resultType.isa<LLVM::LLVMArrayType>()) {
+    Type wideType;
+    // Shift amount necessary to extract the high bits from widened result.
+    Attribute shiftValAttr;
+
+    if (auto intTy = resultType.dyn_cast<IntegerType>()) {
+      unsigned resultBitwidth = intTy.getWidth();
+      wideType = rewriter.getIntegerType(resultBitwidth * 2);
+      shiftValAttr = rewriter.getIntegerAttr(wideType, resultBitwidth);
+    } else {
+      auto vecTy = resultType.cast<VectorType>();
+      unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
+      // Preserve scalable dims: vector<[4]xi32> must widen to vector<[4]xi64>.
+      wideType = VectorType::get(vecTy.getShape(),
+                                 rewriter.getIntegerType(resultBitwidth * 2),
+                                 vecTy.getNumScalableDims());
+      shiftValAttr = SplatElementsAttr::get(
+          wideType, APInt(resultBitwidth * 2, resultBitwidth));
+    }
+    assert(LLVM::isCompatibleType(wideType) &&
+           "LLVM dialect should support all signless integer types");
+
+    Value lhsExt =
+        rewriter.create<LLVM::ZExtOp>(loc, wideType, adaptor.getLhs());
+    Value rhsExt =
+        rewriter.create<LLVM::ZExtOp>(loc, wideType, adaptor.getRhs());
+    Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
+    // Split the 2*N-bit wide result into two N-bit values.
+    Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt);
+    Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr);
+    Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal);
+    Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt);
+    rewriter.replaceOp(op, {low, high});
+    return success();
+  }
+
+  // N-D vectors were converted to LLVM arrays; inspect the unconverted type.
+  if (!op.getLow().getType().isa<VectorType>())
+    return rewriter.notifyMatchFailure(op, "expected vector result type");
+  return rewriter.notifyMatchFailure(op,
+                                     "ND vector types are not supported yet");
+}
+
 //===----------------------------------------------------------------------===//
 // CmpIOpLowering
 //===----------------------------------------------------------------------===//
@@ -397,6 +467,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     MinUIOpLowering,
     MulFOpLowering,
     MulIOpLowering,
+    MulUIExtendedOpLowering,
     NegFOpLowering,
     OrIOpLowering,
     RemFOpLowering,

diff  --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index a127dd8f4e8a6..03e877f56578d 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -223,6 +223,16 @@ class AddUIExtendedOpPattern final
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts arith.mului_extended to spirv.UMulExtended and splits the
+/// resulting struct into its `low` and `high` members.
+class MulUIExtendedOpPattern final
+    : public OpConversionPattern<arith::MulUIExtendedOp> {
+public:
+  using OpConversionPattern<arith::MulUIExtendedOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(arith::MulUIExtendedOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts arith.select to spirv.Select.
 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
 public:
@@ -944,6 +954,26 @@ LogicalResult AddUIExtendedOpPattern::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MulUIExtendedOpPattern
+//===----------------------------------------------------------------------===//
+
+LogicalResult MulUIExtendedOpPattern::matchAndRewrite(
+    arith::MulUIExtendedOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // spirv.UMulExtended yields a two-member struct: {low bits, high bits}.
+  Location opLoc = op->getLoc();
+  Value packed = rewriter.create<spirv::UMulExtendedOp>(
+      opLoc, adaptor.getLhs(), adaptor.getRhs());
+  // Unpack both struct members, preserving the (low, high) result order.
+  Value lowHalf = rewriter.create<spirv::CompositeExtractOp>(
+      opLoc, packed, llvm::makeArrayRef(0));
+  Value highHalf = rewriter.create<spirv::CompositeExtractOp>(
+      opLoc, packed, llvm::makeArrayRef(1));
+  rewriter.replaceOp(op, {lowHalf, highHalf});
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SelectOpPattern
 //===----------------------------------------------------------------------===//
@@ -1040,7 +1070,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,
     CmpFOpNanNonePattern, CmpFOpPattern,
-    AddUIExtendedOpPattern, SelectOpPattern,
+    AddUIExtendedOpPattern, MulUIExtendedOpPattern, SelectOpPattern,
 
     MinMaxFOpPattern<arith::MaxFOp, spirv::GLFMaxOp>,
     MinMaxFOpPattern<arith::MinFOp, spirv::GLFMinOp>,

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index a12c1fe5c6326..7637369c4e552 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -111,6 +111,17 @@ def SubISubILHSRHSLHS :
     Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y), $x),
         (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y)>;
 
+//===----------------------------------------------------------------------===//
+// MulUIExtendedOp
+//===----------------------------------------------------------------------===//
+
+// mului_extended(x, y) -> [muli(x, y), x], when the `high` result is unused.
+// Since `high` ($res__1) is not used, any replacement value will do.
+def MulUIExtendedToMulI :
+    Pattern<(Arith_MulUIExtendedOp:$res $x, $y),
+        [(Arith_MulIOp $x, $y), (replaceWithValue $x)],
+        [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
+
 //===----------------------------------------------------------------------===//
 // XOrIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 29bce4def8745..fde37d6bd69d4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -237,7 +237,7 @@ static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
 LogicalResult
 arith::AddUIExtendedOp::fold(ArrayRef<Attribute> operands,
                              SmallVectorImpl<OpFoldResult> &results) {
-  auto overflowTy = getOverflow().getType();
+  Type overflowTy = getOverflow().getType();
   // addui_extended(x, 0) -> x, false
   if (matchPattern(getRhs(), m_Zero())) {
     auto overflowZero = APInt::getZero(1);
@@ -350,6 +350,61 @@ OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
       operands, [](const APInt &a, const APInt &b) { return a * b; });
 }
 
+//===----------------------------------------------------------------------===//
+// MulUIExtendedOp
+//===----------------------------------------------------------------------===//
+
+Optional<SmallVector<int64_t, 4>> arith::MulUIExtendedOp::getShapeForUnroll() {
+  if (auto vecTy = getLow().getType().dyn_cast<VectorType>())
+    return llvm::to_vector<4>(vecTy.getShape()); // Full vector shape.
+  return std::nullopt; // Non-vector (scalar or tensor) results.
+}
+
+LogicalResult
+arith::MulUIExtendedOp::fold(ArrayRef<Attribute> operands,
+                             SmallVectorImpl<OpFoldResult> &results) {
+  Value rhs = getRhs();
+
+  // x * 0: both halves of the product are zero; reuse the RHS constant.
+  if (matchPattern(rhs, m_Zero())) {
+    Attribute zeroAttr = operands[1];
+    results.append({zeroAttr, zeroAttr});
+    return success();
+  }
+
+  // x * 1: the low half is `x` itself and the high half is zero.
+  if (matchPattern(rhs, m_One())) {
+    Value lhs = getLhs();
+    Attribute zeroAttr = Builder(getContext()).getZeroAttr(lhs.getType());
+    results.append({lhs, zeroAttr});
+    return success();
+  }
+
+  // Both operands constant: fold element-wise into (low, high) constants.
+  auto mulLow = [](const APInt &a, const APInt &b) { return a * b; };
+  auto mulHigh = [](const APInt &a, const APInt &b) {
+    // Multiply at double width, then keep the upper half of the bits.
+    unsigned width = a.getBitWidth();
+    return (a.zext(2 * width) * b.zext(2 * width)).extractBits(width, width);
+  };
+  Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(operands, mulLow);
+  if (!lowAttr)
+    return failure();
+
+  // The same operands folded once already, so the second fold cannot fail.
+  Attribute highAttr = constFoldBinaryOp<IntegerAttr>(operands, mulHigh);
+  assert(highAttr && "Unexpected constant-folding failure");
+
+  results.push_back(lowAttr);
+  results.push_back(highAttr);
+  return success();
+}
+
+void arith::MulUIExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *ctx) {
+  results.add<MulUIExtendedToMulI>(ctx); // Rewrites to muli if `high` unused.
+}
+
 //===----------------------------------------------------------------------===//
 // DivUIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index cf207c283b0d4..8abf8130744a5 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -379,6 +379,38 @@ func.func @addui_extended_vector1d(%arg0: vector<3xi16>, %arg1: vector<3xi16>) -
 
 // -----
 
+// CHECK-LABEL: @mului_extended_scalar
+// CHECK-SAME:    ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> (i32, i32)
+func.func @mului_extended_scalar(%arg0: i32, %arg1: i32) -> (i32, i32) {
+  // CHECK-NEXT: [[LHS:%.+]]  = llvm.zext [[ARG0]] : i32 to i64
+  // CHECK-NEXT: [[RHS:%.+]]  = llvm.zext [[ARG1]] : i32 to i64
+  // CHECK-NEXT: [[MUL:%.+]]  = llvm.mul [[LHS]], [[RHS]] : i64
+  // CHECK-NEXT: [[LOW:%.+]]  = llvm.trunc [[MUL]] : i64 to i32
+  // CHECK-NEXT: [[C32:%.+]]  = llvm.mlir.constant(32 : i64) : i64
+  // CHECK-NEXT: [[SHR:%.+]]  = llvm.lshr [[MUL]], [[C32]] : i64
+  // CHECK-NEXT: [[HIGH:%.+]] = llvm.trunc [[SHR]] : i64 to i32
+  %low, %high = arith.mului_extended %arg0, %arg1 : i32
+  // CHECK-NEXT: return [[LOW]], [[HIGH]] : i32, i32
+  return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @mului_extended_vector1d
+// CHECK-SAME:    ([[ARG0:%.+]]: vector<3xi64>, [[ARG1:%.+]]: vector<3xi64>) -> (vector<3xi64>, vector<3xi64>)
+func.func @mului_extended_vector1d(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> (vector<3xi64>, vector<3xi64>) {
+  // CHECK-NEXT: [[LHS:%.+]]  = llvm.zext [[ARG0]] : vector<3xi64> to vector<3xi128>
+  // CHECK-NEXT: [[RHS:%.+]]  = llvm.zext [[ARG1]] : vector<3xi64> to vector<3xi128>
+  // CHECK-NEXT: [[MUL:%.+]]  = llvm.mul [[LHS]], [[RHS]] : vector<3xi128>
+  // CHECK-NEXT: [[LOW:%.+]]  = llvm.trunc [[MUL]] : vector<3xi128> to vector<3xi64>
+  // CHECK-NEXT: [[C64:%.+]]  = llvm.mlir.constant(dense<64> : vector<3xi128>) : vector<3xi128>
+  // CHECK-NEXT: [[SHR:%.+]]  = llvm.lshr [[MUL]], [[C64]] : vector<3xi128>
+  // CHECK-NEXT: [[HIGH:%.+]] = llvm.trunc [[SHR]] : vector<3xi128> to vector<3xi64>
+  %low, %high = arith.mului_extended %arg0, %arg1 : vector<3xi64>
+  // CHECK-NEXT: return [[LOW]], [[HIGH]] : vector<3xi64>, vector<3xi64>
+  return %low, %high : vector<3xi64>, vector<3xi64>
+}
+
+// -----
+
 // CHECK-LABEL: func @cmpf_2dvector(
 func.func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
   // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast

diff  --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 938bafa357cf3..c2f642e3fb264 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -99,6 +99,29 @@ func.func @int32_vector_addui_extended(%lhs: vector<4xi32>, %rhs: vector<4xi32>)
   return %sum, %overflow : vector<4xi32>, vector<4xi1>
 }
 
+// Check extended unsigned integer multiplication conversions.
+// CHECK-LABEL: @int32_scalar_mului_extended
+// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32)
+func.func @int32_scalar_mului_extended(%lhs: i32, %rhs: i32) -> (i32, i32) {
+  // CHECK-NEXT: %[[PROD:.+]] = spirv.UMulExtended %[[A]], %[[B]] : !spirv.struct<(i32, i32)>
+  // CHECK-DAG: %[[LO:.+]] = spirv.CompositeExtract %[[PROD]][0 : i32] : !spirv.struct<(i32, i32)>
+  // CHECK-DAG: %[[HI:.+]] = spirv.CompositeExtract %[[PROD]][1 : i32] : !spirv.struct<(i32, i32)>
+  // CHECK-NEXT: return %[[LO]], %[[HI]] : i32, i32
+  %lo, %hi = arith.mului_extended %lhs, %rhs : i32
+  return %lo, %hi : i32, i32
+}
+
+// CHECK-LABEL: @int32_vector_mului_extended
+// CHECK-SAME: (%[[A:.+]]: vector<4xi32>, %[[B:.+]]: vector<4xi32>)
+func.func @int32_vector_mului_extended(%lhs: vector<4xi32>, %rhs: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+  // CHECK-NEXT: %[[PROD:.+]] = spirv.UMulExtended %[[A]], %[[B]] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG: %[[LO:.+]] = spirv.CompositeExtract %[[PROD]][0 : i32] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-DAG: %[[HI:.+]] = spirv.CompositeExtract %[[PROD]][1 : i32] : !spirv.struct<(vector<4xi32>, vector<4xi32>)>
+  // CHECK-NEXT: return %[[LO]], %[[HI]] : vector<4xi32>, vector<4xi32>
+  %lo, %hi = arith.mului_extended %lhs, %rhs : vector<4xi32>
+  return %lo, %hi : vector<4xi32>, vector<4xi32>
+}
+
 // Check float unary operation conversions.
 // CHECK-LABEL: @float32_unary_scalar
 func.func @float32_unary_scalar(%arg0: f32) {

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 6a1c0feb9025d..1280a2adece89 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -734,6 +734,104 @@ func.func @adduiExtendedConstantsSplatVector() -> (vector<4xi32>, vector<4xi1>)
   return %sum, %overflow : vector<4xi32>, vector<4xi1>
 }
 
+// CHECK-LABEL: @muluiExtendedZeroRhs
+//  CHECK-NEXT:   %[[zero:.+]] = arith.constant 0 : i32
+//  CHECK-NEXT:   return %[[zero]], %[[zero]]
+func.func @muluiExtendedZeroRhs(%arg0: i32) -> (i32, i32) {
+  %zero = arith.constant 0 : i32
+  %low, %high = arith.mului_extended %arg0, %zero: i32
+  return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @muluiExtendedZeroRhsSplat
+//  CHECK-NEXT:   %[[zero:.+]] = arith.constant dense<0> : vector<3xi32>
+//  CHECK-NEXT:   return %[[zero]], %[[zero]]
+func.func @muluiExtendedZeroRhsSplat(%arg0: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) {
+  %zero = arith.constant dense<0> : vector<3xi32>
+  %low, %high = arith.mului_extended %arg0, %zero: vector<3xi32>
+  return %low, %high : vector<3xi32>, vector<3xi32>
+}
+
+// CHECK-LABEL: @muluiExtendedZeroLhs
+//  CHECK-NEXT:   %[[zero:.+]] = arith.constant 0 : i32
+//  CHECK-NEXT:   return %[[zero]], %[[zero]]
+func.func @muluiExtendedZeroLhs(%arg0: i32) -> (i32, i32) {
+  %zero = arith.constant 0 : i32
+  %low, %high = arith.mului_extended %zero, %arg0: i32
+  return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @muluiExtendedOneRhs
+//  CHECK-SAME:   (%[[ARG:.+]]: i32) -> (i32, i32)
+//  CHECK-NEXT:   %[[zero:.+]] = arith.constant 0 : i32
+//  CHECK-NEXT:   return %[[ARG]], %[[zero]]
+func.func @muluiExtendedOneRhs(%arg0: i32) -> (i32, i32) {
+  %one = arith.constant 1 : i32
+  %low, %high = arith.mului_extended %arg0, %one: i32
+  return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @muluiExtendedOneRhsSplat
+//  CHECK-SAME:   (%[[ARG:.+]]: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>)
+//  CHECK-NEXT:   %[[zero:.+]] = arith.constant dense<0> : vector<3xi32>
+//  CHECK-NEXT:   return %[[ARG]], %[[zero]]
+func.func @muluiExtendedOneRhsSplat(%arg0: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) {
+  %one = arith.constant dense<1> : vector<3xi32>
+  %low, %high = arith.mului_extended %arg0, %one: vector<3xi32>
+  return %low, %high : vector<3xi32>, vector<3xi32>
+}
+
+// CHECK-LABEL: @muluiExtendedOneLhs
+//  CHECK-SAME:   (%[[ARG:.+]]: i32) -> (i32, i32)
+//  CHECK-NEXT:   %[[zero:.+]] = arith.constant 0 : i32
+//  CHECK-NEXT:   return %[[ARG]], %[[zero]]
+func.func @muluiExtendedOneLhs(%arg0: i32) -> (i32, i32) {
+  %one = arith.constant 1 : i32
+  %low, %high = arith.mului_extended %one, %arg0: i32
+  return %low, %high : i32, i32
+}
+
+// CHECK-LABEL: @muluiExtendedUnusedHigh
+//  CHECK-SAME:   (%[[ARG:.+]]: i32) -> i32
+//  CHECK-NEXT:   %[[RES:.+]] = arith.muli %[[ARG]], %[[ARG]] : i32
+//  CHECK-NEXT:   return %[[RES]]
+func.func @muluiExtendedUnusedHigh(%arg0: i32) -> i32 {
+  %low, %high = arith.mului_extended %arg0, %arg0: i32
+  return %low : i32
+}
+
+// Only an unused `high` result is rewritten; an unused `low` must stay as is.
+// CHECK-LABEL: @muluiExtendedUnusedLow
+//  CHECK-SAME:   (%[[ARG:.+]]: i32) -> i32
+//  CHECK-NEXT:   %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[ARG]], %[[ARG]] : i32
+//  CHECK-NEXT:   return %[[HIGH]]
+func.func @muluiExtendedUnusedLow(%arg0: i32) -> i32 {
+  %low, %high = arith.mului_extended %arg0, %arg0: i32
+  return %high : i32
+}
+
+// CHECK-LABEL: @muluiExtendedScalarConstants
+//  CHECK-DAG:    %[[c157:.+]] = arith.constant -99 : i8
+//  CHECK-DAG:    %[[c29:.+]] = arith.constant 29 : i8
+//  CHECK-NEXT:   return %[[c157]], %[[c29]]
+func.func @muluiExtendedScalarConstants() -> (i8, i8) {
+  %c57 = arith.constant 57 : i8
+  %c133 = arith.constant 133 : i8
+  %low, %high = arith.mului_extended %c57, %c133: i8 // = 7581
+  return %low, %high : i8, i8
+}
+
+// CHECK-LABEL: @muluiExtendedVectorConstants
+//  CHECK-DAG:    %[[cstLo:.+]] = arith.constant dense<[65, 79, 1]> : vector<3xi8>
+//  CHECK-DAG:    %[[cstHi:.+]] = arith.constant dense<[0, 14, -2]> : vector<3xi8>
+//  CHECK-NEXT:   return %[[cstLo]], %[[cstHi]]
+func.func @muluiExtendedVectorConstants() -> (vector<3xi8>, vector<3xi8>) {
+  %cstA = arith.constant dense<[5, 37, 255]> : vector<3xi8>
+  %cstB = arith.constant dense<[13, 99, 255]> : vector<3xi8>
+  %low, %high = arith.mului_extended %cstA, %cstB: vector<3xi8>
+  return %low, %high : vector<3xi8>, vector<3xi8>
+}
+
 // CHECK-LABEL: @notCmpEQ
 //       CHECK:   %[[cres:.+]] = arith.cmpi ne, %arg0, %arg1 : i8
 //       CHECK:   return %[[cres]]

diff  --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 99a777d3d5f79..a9c0b557fdf01 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -97,6 +97,30 @@ func.func @test_muli_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]
   return %0 : vector<[8]xi64>
 }
 
+// CHECK-LABEL: test_mului_extended
+func.func @test_mului_extended(%arg0 : i32, %arg1 : i32) -> i32 {
+  %0:2 = arith.mului_extended %arg0, %arg1 : i32
+  return %0#1 : i32
+}
+
+// CHECK-LABEL: test_mului_extended_tensor
+func.func @test_mului_extended_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
+  %prod:2 = arith.mului_extended %arg0, %arg1 : tensor<8x8xi64>
+  return %prod#1 : tensor<8x8xi64>
+}
+
+// CHECK-LABEL: test_mului_extended_vector
+func.func @test_mului_extended_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
+  %lo, %hi = arith.mului_extended %arg0, %arg1 : vector<8xi64>
+  return %lo : vector<8xi64>
+}
+
+// CHECK-LABEL: test_mului_extended_scalable_vector
+func.func @test_mului_extended_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
+  %lo, %hi = arith.mului_extended %arg0, %arg1 : vector<[8]xi64>
+  return %hi : vector<[8]xi64>
+}
+
 // CHECK-LABEL: test_divui
 func.func @test_divui(%arg0 : i64, %arg1 : i64) -> i64 {
   %0 = arith.divui %arg0, %arg1 : i64

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index cdb1536ce65bb..a13429ea8bfb3 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1044,7 +1044,8 @@ void PatternEmitter::emitRewriteLogic() {
   }
 
   if (offsets.front() > 0) {
-    const char error[] = "no enough values generated to replace the matched op";
+    const char error[] =
+        "not enough values generated to replace the matched op";
     PrintFatalError(loc, error);
   }
 


        


More information about the Mlir-commits mailing list