[Mlir-commits] [mlir] Cherry-pick https://github.com/llvm/llvm-project/pull/70218 (PR #70219)

llvmlistbot at llvm.org
Wed Oct 25 08:49:57 PDT 2023


https://github.com/bjacob created https://github.com/llvm/llvm-project/pull/70219


From 1d514d74f885d00ad9eafd38efdb611ffc04b2d2 Mon Sep 17 00:00:00 2001
From: Stella Laurenzo <stellaraccident at gmail.com>
Date: Mon, 16 Oct 2023 17:59:39 -0700
Subject: [PATCH 1/3] Revert "[mlir][tosa][linalg] Apply direct tosa -> linalg
 Conv2D lowering (#68304)"

This reverts commit e29a253c9ebaded53a823def985364392c4ba4ec.

This breaks a TFLite MobileNet test and needs triage.
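
For context: the lowering being reverted mapped tosa.conv2d directly to
linalg.conv_2d_nhwc_fhwc, which shares TOSA's FHWC kernel layout. With the
revert, the converter again transposes the kernel to HWCF and emits
linalg.conv_2d_nhwc_hwcf, as the updated tests below check. A minimal sketch
of the restored lowering, with shapes borrowed from the @conv2d_f32 test (not
literal pass output):

  // The TOSA kernel is FHWC; linalg.conv_2d_nhwc_hwcf expects HWCF, so the
  // converter re-inserts a transpose of the weights before the convolution.
  %perm = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
  %w_hwcf = tosa.transpose %weights, %perm
      : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
  %conv = linalg.conv_2d_nhwc_hwcf
      {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
      ins(%input, %w_hwcf : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>)
      outs(%fill : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>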
---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 137 ------------------
 .../TosaToLinalg/TosaToLinalgNamed.cpp        |  43 +++---
 .../linalg/opdsl/ops/core_named_ops.py        |  30 ----
 .../TosaToLinalg/tosa-to-linalg-named.mlir    |  20 ++-
 4 files changed, 34 insertions(+), 196 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index cd64b813c11e532..44bcbbab2bbe9de 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2575,143 +2575,6 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_arg: KZp
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: conv_2d_nhwc_fhwc_q
-  cpp_class_name: Conv2DNhwcFhwcQOp
-  doc: |-
-    Performs 2-D convolution with zero point offsets.
-
-    Layout:
-      * Input: NHWC.
-      * Kernel: FHWC.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output. This includes the zero
-    point offsets common to quantized operations.
-  implements:
-  - LinalgConvolutionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: I
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
-      s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
-  - !LinalgOperandDefConfig
-    name: K
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10,
-      s3, s7, s9)>
-  - !LinalgOperandDefConfig
-    name: IZp
-    kind: scalar
-    type_var: I32
-  - !LinalgOperandDefConfig
-    name: KZp
-    kind: scalar
-    type_var: I32
-  - !LinalgOperandDefConfig
-    name: O
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
-      s1, s5, s10)>
-  - !LinalgOperandDefConfig
-    name: strides
-    kind: index_attr
-    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
-      (s2, s6)>
-    default_indices:
-    - 1
-    - 1
-  - !LinalgOperandDefConfig
-    name: dilations
-    kind: index_attr
-    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
-      (s4, s8)>
-    default_indices:
-    - 1
-    - 1
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10] -> (d0, d1 * s2 + d4 * s4, d2 * s6 + d5 * s8, d6)>
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10] -> (d3, d4, d5, d6)>
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10] -> ()>
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10] -> ()>
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10] -> (d0, d1, d2, d3)>
-  iterator_types:
-  - parallel
-  - parallel
-  - parallel
-  - parallel
-  - reduction
-  - reduction
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: O
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: O
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: binary
-                fn_name: sub
-                operands:
-                - !ScalarExpression
-                  scalar_fn:
-                    kind: type
-                    fn_name: cast_signed
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: I
-                - !ScalarExpression
-                  scalar_fn:
-                    kind: type
-                    fn_name: cast_signed
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: IZp
-            - !ScalarExpression
-              scalar_fn:
-                kind: binary
-                fn_name: sub
-                operands:
-                - !ScalarExpression
-                  scalar_fn:
-                    kind: type
-                    fn_name: cast_signed
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: K
-                - !ScalarExpression
-                  scalar_fn:
-                    kind: type
-                    fn_name: cast_signed
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: KZp
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_nchw_fchw
   cpp_class_name: Conv2DNchwFchwOp
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index ee8f52deadbd152..7ef1374c93ba86e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -248,28 +248,25 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     pad.resize(pad.size() + 2, 0);
     input = applyPad(loc, input, pad, zeroAttr, rewriter);
 
-    // For Conv3D transpose the kernel to match dimension ordering of the linalg
-    // convolution operation. Conv2D has a 1-1 mapping in linalg so better to
-    // map directly and then transpose later if desired.
-    if (5 == inputTy.getRank()) {
-      // TODO(suderman): See if this can be efficiently folded - check whether
-      // the input is used anywhere else, if not fold the constant.
-      SmallVector<int64_t> weightPerm;
-      for (int i = 1; i < resultTy.getRank(); i++)
-        weightPerm.push_back(i);
-      weightPerm.push_back(0);
-
-      SmallVector<int64_t> newWeightShape;
-      for (auto dim : weightPerm)
-        newWeightShape.push_back(weightShape[dim]);
-      auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
-      Value weightPermValue =
-          rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
-      Type newWeightTy =
-          RankedTensorType::get(newWeightShape, weightTy.getElementType());
-      weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
-                                                  weightPermValue);
-    }
+    // Transpose the kernel to match dimension ordering of the linalg
+    // convolution operation.
+    // TODO(suderman): See if this can be efficiently folded - check whether
+    // the input is used anywhere else, if not fold the constant.
+    SmallVector<int64_t> weightPerm;
+    for (int i = 1; i < resultTy.getRank(); i++)
+      weightPerm.push_back(i);
+    weightPerm.push_back(0);
+
+    SmallVector<int64_t> newWeightShape;
+    for (auto dim : weightPerm)
+      newWeightShape.push_back(weightShape[dim]);
+    auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+    Value weightPermValue =
+        rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+    Type newWeightTy =
+        RankedTensorType::get(newWeightShape, weightTy.getElementType());
+    weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
+                                                weightPermValue);
 
     auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
@@ -980,7 +977,7 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
     RewritePatternSet *patterns) {
   patterns->add<
       // clang-format off
-      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcFhwcQOp>,
+      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcHwcfQOp>,
       ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
       DepthwiseConvConverter,
       MatMulConverter,
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 19734a80a107bfe..e8bdb9180471927 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -702,36 +702,6 @@ def conv_2d_nhwc_hwcf_q(
     ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp))
 
 
- at linalg_structured_op
-def conv_2d_nhwc_fhwc_q(
-    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
-    K=TensorDef(T2, S.F, S.KH, S.KW, S.C),
-    IZp=ScalarDef(I32),
-    KZp=ScalarDef(I32),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
-):
-    """Performs 2-D convolution with zero point offsets.
-
-    Layout:
-      * Input: NHWC.
-      * Kernel: FHWC.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output. This includes the zero
-    point offsets common to quantized operations.
-    """
-    implements(ConvolutionOpInterface)
-    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
-    O[D.n, D.oh, D.ow, D.f] += (
-        TypeFn.cast_signed(
-            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
-        )
-        - TypeFn.cast_signed(U, IZp)
-    ) * (TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) - TypeFn.cast_signed(U, KZp))
-
-
 @linalg_structured_op
 def conv_2d_nchw_fchw(
     I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index b601bfb28a4f280..bf970c84832e9e5 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -363,11 +363,13 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 
 // CHECK-LABEL: @conv2d_i8
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
   // CHECK:   arith.extsi
   // CHECK:   arith.addi
@@ -383,11 +385,13 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
 
 // CHECK-LABEL: @conv2d_f32
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
   // CHECK:   arith.addf
   // CHECK:   linalg.yield
@@ -404,11 +408,13 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
 func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]])
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<?x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>)
   // CHECK:   %[[ADD:.+]] = arith.addf
   // CHECK:   linalg.yield %[[ADD]] : f32
@@ -462,11 +468,13 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
   // CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index
 
   // Running convolution
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[WEIGHT:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[WEIGHT]] : tensor<1x?x?x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>)
   // CHECK:   %[[ADD:.+]] = arith.addf
   // CHECK:   linalg.yield %[[ADD]] : f32
@@ -481,7 +489,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK:   tensor.yield %[[C0]]
-  // CHECK: linalg.conv_2d_nhwc_fhwc
+  // CHECK: linalg.conv_2d_nhwc_hwcf
   %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
   return
 }
@@ -493,7 +501,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
   // CHECK:   %[[C22:.+]] = arith.constant -22
   // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK:   tensor.yield %[[C22]]
-  // CHECK: linalg.conv_2d_nhwc_fhwc_q
+  // CHECK: linalg.conv_2d_nhwc_hwcf_q
   %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
   return
 }

From 6aaa03a0232b89682e0745433bdfb4d785f64a01 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Tue, 24 Oct 2023 17:00:53 -0700
Subject: [PATCH 2/3] Revert "Reland "[MLIR][LLVM] Change addressof builders to
 use opaque pointers" (#69292)"

This reverts commit 484668c7597d9198e21332b30d2f15ece536a0bb.
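
For context: this restores typed pointers for the AddressOfOp builders and for
the GPU printf-to-vprintf lowering. The effect is easiest to see in the
gpu-to-nvvm.mlir expectations updated below; roughly, with snippets lifted from
those CHECK lines (the two forms belong to opposite sides of the revert and are
shown together only for comparison):

  // Opaque-pointer form (being reverted):
  %f0 = llvm.mlir.addressof @fmt : !llvm.ptr
  %s0 = llvm.getelementptr %f0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
  // Typed-pointer form (restored by this revert); a bitcast to the
  // !llvm.ptr<i8> that vprintf expects is also reintroduced:
  %f1 = llvm.mlir.addressof @fmt : !llvm.ptr<array<14 x i8>>
  %s1 = llvm.getelementptr %f1[0, 0] : (!llvm.ptr<array<14 x i8>>) -> !llvm.ptr<i8>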
---
 mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp    |  7 +++--
 mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp    |  7 +++--
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  4 +--
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   | 18 ++++++------
 .../Conversion/GPUToNVVM/gpu-to-nvvm.mlir     | 28 ++++++++++---------
 5 files changed, 34 insertions(+), 30 deletions(-)

diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index f05f1c2dc33881d..684ce37b2398ce2 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -117,8 +117,8 @@ class PrintOpLowering : public ConversionPattern {
   ///   * `i32 (i8*, ...)`
   static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
     auto llvmI32Ty = IntegerType::get(context, 32);
-    auto llvmPtrTy = LLVM::LLVMPointerType::get(context);
-    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
+    auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
+    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
                                                   /*isVarArg=*/true);
     return llvmFnType;
   }
@@ -162,7 +162,8 @@ class PrintOpLowering : public ConversionPattern {
     Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
                                                   builder.getIndexAttr(0));
     return builder.create<LLVM::GEPOp>(
-        loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(),
+        loc,
+        LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
         globalPtr, ArrayRef<Value>({cst0, cst0}));
   }
 };
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index f05f1c2dc33881d..684ce37b2398ce2 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -117,8 +117,8 @@ class PrintOpLowering : public ConversionPattern {
   ///   * `i32 (i8*, ...)`
   static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
     auto llvmI32Ty = IntegerType::get(context, 32);
-    auto llvmPtrTy = LLVM::LLVMPointerType::get(context);
-    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
+    auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
+    auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
                                                   /*isVarArg=*/true);
     return llvmFnType;
   }
@@ -162,7 +162,8 @@ class PrintOpLowering : public ConversionPattern {
     Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
                                                   builder.getIndexAttr(0));
     return builder.create<LLVM::GEPOp>(
-        loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(),
+        loc,
+        LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
         globalPtr, ArrayRef<Value>({cst0, cst0}));
   }
 };
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 2a572ab4de706a3..8745d14c8d48318 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1071,7 +1071,7 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       build($_builder, $_state,
-            LLVM::LLVMPointerType::get($_builder.getContext(), global.getAddrSpace()),
+            LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()),
             global.getSymName());
       $_state.addAttributes(attrs);
     }]>,
@@ -1079,7 +1079,7 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       build($_builder, $_state,
-            LLVM::LLVMPointerType::get($_builder.getContext()), func.getName());
+            LLVM::LLVMPointerType::get(func.getFunctionType()), func.getName());
       $_state.addAttributes(attrs);
     }]>
   ];
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 6d2585aa30ab4c5..96d8fceba706617 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -441,7 +441,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   Location loc = gpuPrintfOp->getLoc();
 
   mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
-  mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+  mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
 
   // Note: this is the GPUModule op, not the ModuleOp that surrounds it
   // This ensures that global constants and declarations are placed within
@@ -449,7 +449,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
 
   auto vprintfType =
-      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
+      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr, i8Ptr});
   LLVM::LLVMFuncOp vprintfDecl =
       getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
 
@@ -473,8 +473,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   // Get a pointer to the format string's first element
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
   Value stringStart = rewriter.create<LLVM::GEPOp>(
-      loc, getTypeConverter()->getPointerType(globalType), globalType,
-      globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+      loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
   SmallVector<Type> types;
   SmallVector<Value> args;
   // Promote and pack the arguments into a stack allocation.
@@ -491,17 +490,18 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   }
   Type structType =
       LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
+  Type structPtrType = LLVM::LLVMPointerType::get(structType);
   Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
                                                 rewriter.getIndexAttr(1));
-  Value tempAlloc =
-      rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
-                                      /*alignment=*/0);
+  Value tempAlloc = rewriter.create<LLVM::AllocaOp>(loc, structPtrType, one,
+                                                    /*alignment=*/0);
   for (auto [index, arg] : llvm::enumerate(args)) {
     Value ptr = rewriter.create<LLVM::GEPOp>(
-        loc, getTypeConverter()->getPointerType(structType), structType,
-        tempAlloc, ArrayRef<LLVM::GEPArg>{0, index});
+        loc, LLVM::LLVMPointerType::get(arg.getType()), tempAlloc,
+        ArrayRef<LLVM::GEPArg>{0, index});
     rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
   }
+  tempAlloc = rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, tempAlloc);
   std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
 
   rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index a8c02e32ef92b6b..391ccd74841dca4 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -542,15 +542,16 @@ gpu.module @test_module_28 {
 gpu.module @test_module_29 {
   // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00")
   // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00")
-  // CHECK-DAG: llvm.func @vprintf(!llvm.ptr, !llvm.ptr) -> i32
+  // CHECK-DAG: llvm.func @vprintf(!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
 
   // CHECK-LABEL: func @test_const_printf
   gpu.func @test_const_printf() {
-    // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr
-    // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
+    // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr<array<14 x i8>>
+    // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<14 x i8>>) -> !llvm.ptr<i8>
     // CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
-    // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr
-    // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
+    // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr<struct<()>>
+    // CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<()>> to !llvm.ptr<i8>
+    // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
     gpu.printf "Hello, world\n"
     gpu.return
   }
@@ -558,16 +559,17 @@ gpu.module @test_module_29 {
   // CHECK-LABEL: func @test_printf
   // CHECK: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32)
   gpu.func @test_printf(%arg0: i32, %arg1: f32) {
-    // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr
-    // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<11 x i8>
+    // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr<array<11 x i8>>
+    // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<11 x i8>>) -> !llvm.ptr<i8>
     // CHECK-NEXT: %[[EXT:.+]] = llvm.fpext %[[ARG1]] : f32 to f64
     // CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
-    // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr
-    // CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, f64)>
-    // CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : i32, !llvm.ptr
-    // CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, f64)>
-    // CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : f64, !llvm.ptr
-    // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
+    // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr<struct<(i32, f64)>>
+    // CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<i32>
+    // CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : !llvm.ptr<i32>
+    // CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<f64>
+    // CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : !llvm.ptr<f64>
+    // CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<(i32, f64)>> to !llvm.ptr<i8>
+    // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
     gpu.printf "Hello: %d\n" %arg0, %arg1 : i32, f32
     gpu.return
   }

From fd6c88d876636fc1a671436ef04cf65e2c4bc483 Mon Sep 17 00:00:00 2001
From: bjacob <jacob.benoit.1 at gmail.com>
Date: Wed, 25 Oct 2023 11:41:24 -0400
Subject: [PATCH 3/3] Add missing `linalg.batch_vecmat` named op (#70218)

Linalg currently has these named ops:
* `matmul`
* `matvec`
* `vecmat`
* `batch_matmul`
* `batch_matvec`

But it does not have:
* `batch_vecmat`

This PR adds it for consistency. I also have a short-term need for it
(https://github.com/openxla/iree/issues/15158), so not having it would cause
some contortion on my end.
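
For reference, the new op computes C[b, n] += cast(A[b, k]) * cast(B[b, k, n]),
matching the indexing maps in the YAML below. A minimal usage sketch on static
tensor shapes (shapes chosen for illustration, not taken from this patch):

  func.func @batch_vecmat_example(%vec: tensor<4x8xf32>, %mat: tensor<4x8x16xf32>,
                                  %acc: tensor<4x16xf32>) -> tensor<4x16xf32> {
    // Batch size 4, reduction size k = 8, result size n = 16.
    %0 = linalg.batch_vecmat ins(%vec, %mat : tensor<4x8xf32>, tensor<4x8x16xf32>)
                             outs(%acc : tensor<4x16xf32>) -> tensor<4x16xf32>
    return %0 : tensor<4x16xf32>
  }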
---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 68 +++++++++++++++++++
 .../linalg/opdsl/ops/core_named_ops.py        | 18 +++++
 .../Dialect/Linalg/generalize-named-ops.mlir  | 25 +++++++
 3 files changed, 111 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 44bcbbab2bbe9de..4198ba76b4d61f5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1796,6 +1796,74 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: B
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: batch_vecmat
+  cpp_class_name: BatchVecmatOp
+  doc: |-
+    Performs a batched vector-matrix multiplication.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+  - !LinalgOperandDefConfig
+    name: B
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1, s2)>
+  - !LinalgOperandDefConfig
+    name: C
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2, d1)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: dot
   cpp_class_name: DotOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index e8bdb9180471927..52ad44670787299 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -517,6 +517,24 @@ def batch_matvec(
     )
 
 
+ at linalg_structured_op
+def batch_vecmat(
+    A=TensorDef(T1, Batch, S.K),
+    B=TensorDef(T2, Batch, S.K, S.N),
+    C=TensorDef(U, Batch, S.N, output=True),
+):
+    """Performs a batched matrix-vector multiplication.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.b, D.n, D.k)
+    implements(ContractionOpInterface)
+    C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed(
+        U, B[D.b, D.k, D.n]
+    )
+
+
 @linalg_structured_op
 def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)):
     """Performs a dot product of two vectors to a scalar result.
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 54cc0defc1f8cd8..2259d47eb2b2b0d 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -251,6 +251,31 @@ func.func @generalize_batch_matm_vec(%lhs : memref<?x?x?xi8>, %rhs: memref<?x?xi
 
 // -----
 
+func.func @generalize_batch_vecmat(%lhs : memref<?x?xi8>, %rhs: memref<?x?x?xi8>,  %out: memref<?x?xf32>) {
+  linalg.batch_vecmat ins(%lhs, %rhs: memref<?x?xi8>, memref<?x?x?xi8>)
+                     outs(%out: memref<?x?xf32>)
+  return
+}
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: @generalize_batch_vecmat
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?xi8>, memref<?x?x?xi8>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: f32)
+// CHECK:            %[[BBARG0_F32:.+]] = arith.sitofp %[[BBARG0]] : i8 to f32
+// CHECK:            %[[BBARG1_F32:.+]] = arith.sitofp %[[BBARG1]] : i8 to f32
+// CHECK:            %[[MUL:.+]] = arith.mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
+// CHECK:            %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]]
+// CHECK:            linalg.yield %[[ADD]] : f32
+
+// -----
+
 func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %out: memref<8x8xf32>) {
   linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xf32>, memref<7x9x8xf32>)
                              outs(%out: memref<8x8xf32>)


