Hello,

I have just written an implementation of mpn_mulhigh_basecase for Broadwell-type processors (that is, x86-64 with the BMI2 and ADX instruction-set extensions), based on Torbjörn's `mpn_mullo_basecase'.

It is currently declared as

mp_limb_t flint_mpn_mulhigh_basecase(mp_ptr rp, mp_srcptr xp, mp_srcptr yp, mp_size_t n),

as it was developed for FLINT (Fast Library for Number Theory). Note that `rp' must not overlap (be aliased with) `xp' or `yp'.

It computes an approximation of the uppermost `n + 1' limbs of `{xp, n}' times `{yp, n}': the uppermost `n' limbs are stored to `rp[0]', ..., `rp[n - 1]', and the least significant computed limb is returned (in %rax). This returned limb should have an error on the order of `n' ULPs.

Note that this differs from MPFR's (private) function `mpfr_mulhigh_n', which computes an approximation of the uppermost `n' limbs into `rp[n]', ..., `rp[2 * n - 1]', where `rp[n]' has an error of at most roughly `n' ULPs.

Feel free to change it according to your needs (perhaps you do not want to compute `n + 1' limbs, but rather `n' limbs).

If this code will be used in GMP, feel free to remove the copyright claim for FLINT and put my name (spelled Albin Ahlbäck) in the GMP copyright claim instead.

Just some notes:

- We use our own M4 syntax for the beginning and end of the function, but it should be easy to translate to GMP's syntax.
- It currently only works for n > 5 (I believe; the first addmul row unconditionally reads `xp[n - 6]' and `yp[4]'), as we in FLINT have specialized routines for small n.
- It would be nice to avoid pushing five registers, and only push four.
- Reduce the size of the `L(end)' code, and try to avoid multiple jumps therein.
- Move the code snippet of `L(f2)' to just above `L(b2)', so that no jump is needed in between. (This currently does not work because `L(end)' together with this code snippet is too large for a relative 8-bit jump.)
- Start out with a mul_1 sequence consisting of just a mulx+add+adc chain, just like in `mpn_mullo_basecase'.
- Remove the first multiplication in each `L(fX)' and put it in `L(end)' instead.
- The `adcx' instruction in `L(fX)' can be removed (one then needs to adjust the `L(bX)' labels), but I found that to be slower. Can we remove it and somehow maintain the same performance?

Best,
Albin
dnl  x86-64 mpn_mulhigh_basecase optimised for Intel Broadwell, derived from mpn_mullo_basecase.

dnl  Contributed to the GNU project by Torbjorn Granlund.

dnl  Copyright 2017 Free Software Foundation, Inc.

dnl  This file is part of the GNU MP Library.
dnl
dnl  The GNU MP Library is free software; you can redistribute it and/or modify
dnl  it under the terms of either:
dnl
dnl    * the GNU Lesser General Public License as published by the Free
dnl      Software Foundation; either version 3 of the License, or (at your
dnl      option) any later version.
dnl
dnl  or
dnl
dnl    * the GNU General Public License as published by the Free Software
dnl      Foundation; either version 2 of the License, or (at your option) any
dnl      later version.
dnl
dnl  or both in parallel, as here.
dnl
dnl  The GNU MP Library is distributed in the hope that it will be useful, but
dnl  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
dnl  or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
dnl  for more details.
dnl
dnl  You should have received copies of the GNU General Public License and the
dnl  GNU Lesser General Public License along with the GNU MP Library.  If not,
dnl  see https://www.gnu.org/licenses/.
dnl
dnl   Copyright (C) 2024 Albin Ahlbäck
dnl
dnl   This file is part of FLINT.
dnl
dnl   FLINT is free software: you can redistribute it and/or modify it under
dnl   the terms of the GNU Lesser General Public License (LGPL) as published
dnl   by the Free Software Foundation; either version 3 of the License, or
dnl   (at your option) any later version.  See <https://www.gnu.org/licenses/>.
dnl

include(`config.m4')
include(`src/mpn_extras/broadwell/asm-defs.m4')

dnl  Register map.  Arguments arrive in System V AMD64 order: rdi, rsi, rdx, rcx.
dnl    rp        result pointer; rp[0..n-1] receives the upper n limbs.
dnl    ap        xp; advanced at entry so that it points at xp[n-1].
dnl    bp_param  yp as passed.  It is moved to bp at entry because rdx is the
dnl              implicit multiplicand of every mulx.
dnl    n         limb count of both operands; after the initial triangle it is
dnl              reused as the per-row counter of 8-limb passes, tested by jrcxz.
define(`rp',       `%rdi')
define(`ap',       `%rsi')
define(`bp_param', `%rdx')
define(`n',        `%rcx')

dnl  Loop state.  During the initial triangle, jmpreg, m, mm and nn are still
dnl  free, so the function aliases them as the temporaries s0, s1, s2 and s3.
dnl    bp        pointer to the current yp limb.
dnl    jmpreg    address of the fX entry point of the next addmul row.
dnl    nn        number of 8-limb passes for the current row; it is reloaded
dnl              into n at the end of every row.
dnl    m         -8 times the current row number (starts at -8*4).
dnl    mm        -8*(n-1), the value m takes for the last row.
define(`bp',       `%r8')
define(`jmpreg',   `%r9')
define(`nn',       `%r10')
define(`m',        `%r13')
define(`mm',       `%r14')

dnl  Extra low limb, one position below rp[0]; it is the return value.
define(`rx',       `%rax')

dnl  Product and carry temporaries.  r1, r2 and r3 live in callee-saved
dnl  registers, as do m and mm, which is why the function saves five registers.
define(`r0',       `%r11')
define(`r1',       `%rbx')
define(`r2',       `%rbp')
define(`r3',       `%r12')

# Idea: Do similar to mpn_mullo_basecase for Skylake.

#-----------------------------------------------------------------------
# mp_limb_t flint_mpn_mulhigh_basecase (mp_ptr rp, mp_srcptr xp,
#                                       mp_srcptr yp, mp_size_t n)
#
# Calling convention: System V AMD64 (rdi = rp, rsi = xp, rdx = yp,
#   rcx = n).  Requires BMI2 (mulx) and ADX (adcx, adox).
# Requires: n >= 6.  The initial triangle reads xp[n-5] and yp[0..3],
#   and the first addmul row reads xp[n-6] and yp[4] unconditionally.
#   rp must not overlap xp or yp.
# Output: rp[0..n-1] = approximation of the upper n limbs of
#   {xp,n} * {yp,n}.  rax = approximation of the limb just below them.
#   Writing position k for the limb of weight 2^(64k) in the full 2n-limb
#   product, rx holds position n-1 and rp[k] holds position n+k.
#   For every yp[j], the products xp[i]*yp[j] with i+j >= n-1 are added in
#   full; for i+j = n-2 only the high half is added (into rx); lower
#   products are not computed.
# Clobbers: rax, rcx, rdx, rsi, rdi, r8-r11, flags.  rbx, rbp and
#   r12-r14 are saved and restored.  Uses 40 bytes of stack; makes no calls.
#-----------------------------------------------------------------------

.text

.global FUNC(flint_mpn_mulhigh_basecase)
ALIGN(32)
TYPE(flint_mpn_mulhigh_basecase)

FUNC(flint_mpn_mulhigh_basecase):
        .cfi_startproc
        mov     bp_param, bp            # free rdx: mulx uses it as a source
        lea     -1*8(ap,n,8), ap        # ap += n - 1, so ap points at xp[n-1]

        # Save callee-saved registers.  Each push moves the CFA by 8 and
        # records the save slot, so unwinders can walk through this frame.
        push    %rbx
        .cfi_adjust_cfa_offset 8
        .cfi_offset %rbx, -16
        push    %rbp
        .cfi_adjust_cfa_offset 8
        .cfi_offset %rbp, -24
        push    %r12
        .cfi_adjust_cfa_offset 8
        .cfi_offset %r12, -32
        push    %r13
        .cfi_adjust_cfa_offset 8
        .cfi_offset %r13, -40
        push    %r14
        .cfi_adjust_cfa_offset 8
        .cfi_offset %r14, -48

        # Initial triangle: rows yp[0..3], accumulated in rx and r0..r3
        # and then stored to rp[0..3].
        #       h
        #     h x
        #   h x x
        # h x x x
        # x x x x
        # h = only the high half of that product is used.
        # The registers used later for loop state are not live yet,
        # so they serve as the temporaries s0..s3 here.
define(`s0', `jmpreg')
define(`s1', `m')
define(`s2', `mm')
define(`s3', `nn')
        # Row 0, multiplier yp[0]
        mov     0*8(bp), %rdx
        xor     R32(s3), R32(s3)        # s3 = 0 as carry sink; clears CF and OF
        mulx    -1*8(ap), rx, rx        # same destination twice: keep high half only
        mulx    0*8(ap), s0, r0
        add     s0, rx
        adc     s3, r0

        # Row 1, multiplier yp[1]
        mov     1*8(bp), %rdx
        mulx    -2*8(ap), s1, s1        # high half only
        mulx    -1*8(ap), r3, r2
        mulx    0*8(ap), s0, r1
        add     r3, rx
        adc     s0, r0
        adc     s3, r1
        add     s1, rx
        adc     r2, r0
        adc     s3, r1

        # Row 2, multiplier yp[2]
        mov     2*8(bp), %rdx
        mulx    -3*8(ap), s0, s0        # high half only
        mulx    -2*8(ap), r3, s1
        add     s0, rx
        adc     s1, r0
        mulx    -1*8(ap), s0, s1
        mulx    0*8(ap), %rdx, r2
        adc     s1, r1
        adc     s3, r2
        add     r3, rx
        adc     s0, r0
        adc     %rdx, r1
        adc     s3, r2

        # Row 3, multiplier yp[3].  s3 is overwritten by a product below,
        # so from here on the carries are absorbed with an immediate 0.
        mov     3*8(bp), %rdx
        mulx    -4*8(ap), s1, s1        # high half only
        mulx    -3*8(ap), s0, s2
        add     s1, rx
        adc     s2, r0
        mulx    -2*8(ap), s1, r3
        mulx    -1*8(ap), s2, s3
        adc     r3, r1
        adc     s3, r2
        mulx    0*8(ap), %rdx, r3
        adc     $0, r3
        add     s0, rx
        adc     s1, r0
        adc     s2, r1
        mov     r0, 0*8(rp)
        mov     r1, 1*8(rp)
        adc     %rdx, r2
        adc     $0, r3
        mov     r2, 2*8(rp)
        mov     r3, 3*8(rp)
undefine(`s0')
undefine(`s1')
undefine(`s2')
undefine(`s3')

        # Addmul chains, one row per yp[j] for j = 4 .. n-1
        # - m = -8 * n_cur      (n_cur is the 4 at the start)
        # - mm = -8 * (n - 1)   (where n is the original n)
        # - n keeps track of how many loops to do in the addmul-loop.
        # - nn keeps track of initial n between loops.
        # Every row except the last first adds the high half of
        # xp[n-2-j]*yp[j] into rx, then enters at fX with X = j mod 8.
        lea     -1*8(,n,8), R32(mm)
        lea     4*8(bp), bp
        lea     -3*8(ap), ap
        mov     $-4*8, m                # m <- -8 * 4
        neg     mm                      # mm <- -8 * (n - 1)
        mov     0*8(bp), %rdx           # multiplier yp[4]
        xor     R32(nn), R32(nn)        # nn <- 0, also clears CF and OF
        xor     R32(n), R32(n)          # n <- 0
        mulx    -2*8(ap), r1, r1        # high half of xp[n-6]*yp[4]; needs n >= 6
        adcx    r1, rx

        # Row entry points fX.  Each one starts its row with the lowest
        # full products, adjusts ap and rp so the unrolled loop below lands
        # on the matching bX label, and stores the entry of the next row
        # in jmpreg (f4 -> f5 -> ... -> f0 -> f1 -> ... -> f4).
        # The adcx in each fX continues the CF chain started on rx.
L(f4):  mulx    -1*8(ap), r2, r3
        mulx    0*8(ap), r0, r1
        adox    r2, rx
        adcx    r3, r0
        lea     3*8(ap), ap
        lea     -5*8(rp), rp
        lea     L(f5)(%rip), jmpreg
        jmp     L(b4)

L(f0):  mulx    -1*8(ap), r2, r3
        mulx    0*8(ap), r0, r1
        adox    r2, rx
        adcx    r3, r0
        lea     -1*8(ap), ap
        lea     -1*8(rp), rp
        lea     L(f1)(%rip), jmpreg
        jmp     L(b0)

        # f1 is reached once every 8 rows; the row is now one full 8-limb
        # pass longer, so both the saved count nn and the live count n grow.
L(f1):  mulx    -1*8(ap), r0, r1
        mulx    0*8(ap), r2, r3
        adox    r0, rx
        adcx    r1, r2
        lea     1(nn), R32(nn)
        lea     1(n), R32(n)
        lea     L(f2)(%rip), jmpreg
        jmp     L(b1)

L(f7):  mulx    -1*8(ap), r0, r1
        mulx    0*8(ap), r2, r3
        adox    r0, rx
        adcx    r1, r2
        lea     -2*8(ap), ap
        lea     -2*8(rp), rp
        lea     L(f0)(%rip), jmpreg
        jmp     L(b7)

L(f2):  mulx    -1*8(ap), r2, r3
        mulx    0*8(ap), r0, r1
        adox    r2, rx
        adcx    r3, r0
        lea     1*8(ap), ap
        lea     1*8(rp), rp
        mulx    0*8(ap), r2, r3
        lea     L(f3)(%rip), jmpreg
        jmp     L(b2)

        # End of row: fold the last limb into rp, absorb both carry chains
        # into the new top limb, and set up the next row.
L(end): adox    0*8(rp), r2
        mov     r2, 0*8(rp)
        adox    n, r3           # n = 0
        adc     n, r3           # n = 0
        add     m, ap           # Reset ap
        mov     r3, 1*8(rp)     # new top limb of the row; not added, just stored
        lea     -1*8(m), m      # m <- -8 * (next row number)
        lea     1*8(bp), bp     # Increase bp
        lea     2*8(rp,m), rp   # Reset rp
        mov     0*8(bp), %rdx   # Load bp
        cmp     R32(m), R32(mm) # signed compare of mm against m
        jge     L(jmp)
        # If |m| < |mm|: goto jmpreg, but first do high part
        or      R32(nn), R32(n) # Reset n, CF and OF
        mulx    -2*8(ap), r1, r1
        adcx    r1, rx
        jmp     *jmpreg
        # If |m| > |mm|: goto fin
L(jmp): jg      L(fin)
        # If |m| = |mm|: goto jmpreg.  This is the last row, j = n-1,
        # which has no high-half-only product because xp[n-2-j] would be xp[-1].
        or      R32(nn), R32(n) # Reset n, clear CF and OF
        jmp     *jmpreg

        # 8-way unrolled addmul loop.  The adcx chain (CF) links the high
        # half of each product to the low half of the next one.  The adox
        # chain (OF) adds in the limbs already stored at rp.  n counts the
        # remaining 8-limb passes of this row and is tested by jrcxz.
        ALIGN(32)
L(b2):  adox    -1*8(rp), r0
        adcx    r1, r2
        mov     r0, -1*8(rp)
        jrcxz   L(end)  # Jump if n = 0
L(b1):  mulx    1*8(ap), r0, r1
        adox    0*8(rp), r2
        lea     -1(n), R32(n)
        mov     r2, 0*8(rp)
        adcx    r3, r0
L(b0):  mulx    2*8(ap), r2, r3
        adcx    r1, r2
        adox    1*8(rp), r0
        mov     r0, 1*8(rp)
L(b7):  mulx    3*8(ap), r0, r1
        lea     8*8(ap), ap
        adcx    r3, r0
        adox    2*8(rp), r2
        mov     r2, 2*8(rp)
L(b6):  mulx    -4*8(ap), r2, r3
        adox    3*8(rp), r0
        adcx    r1, r2
        mov     r0, 3*8(rp)
L(b5):  mulx    -3*8(ap), r0, r1
        adcx    r3, r0
        adox    4*8(rp), r2
        mov     r2, 4*8(rp)
L(b4):  mulx    -2*8(ap), r2, r3
        adox    5*8(rp), r0
        adcx    r1, r2
        mov     r0, 5*8(rp)
L(b3):  adox    6*8(rp), r2
        mulx    -1*8(ap), r0, r1
        mov     r2, 6*8(rp)
        lea     8*8(rp), rp
        adcx    r3, r0
        mulx    0*8(ap), r2, r3
        jmp     L(b2)

        # Remaining row entry points, placed after the loop to keep the
        # jumps from the end-of-row code within short range.
L(f6):  mulx    -1*8(ap), r2, r3
        mulx    0*8(ap), r0, r1
        adox    r2, rx
        adcx    r3, r0
        lea     5*8(ap), ap
        lea     -3*8(rp), rp
        lea     L(f7)(%rip), jmpreg
        jmp     L(b6)

L(f5):  mulx    -1*8(ap), r0, r1
        mulx    0*8(ap), r2, r3
        adox    r0, rx
        adcx    r1, r2
        lea     4*8(ap), ap
        lea     -4*8(rp), rp
        lea     L(f6)(%rip), jmpreg
        jmp     L(b5)

L(f3):  mulx    -1*8(ap), r0, r1
        mulx    0*8(ap), r2, r3
        adox    r0, rx
        adcx    r1, r2
        lea     2*8(ap), ap
        lea     -6*8(rp), rp
        lea     L(f4)(%rip), jmpreg
        jmp     L(b3)

        # Epilogue.  This is the last code in address order, so unwinding
        # the CFA here leaves the unwind info correct for all code above.
L(fin): pop     %r14
        .cfi_adjust_cfa_offset -8
        .cfi_restore %r14
        pop     %r13
        .cfi_adjust_cfa_offset -8
        .cfi_restore %r13
        pop     %r12
        .cfi_adjust_cfa_offset -8
        .cfi_restore %r12
        pop     %rbp
        .cfi_adjust_cfa_offset -8
        .cfi_restore %rbp
        pop     %rbx
        .cfi_adjust_cfa_offset -8
        .cfi_restore %rbx

        ret                     # rax = rx, the extra low limb
.flint_mpn_mulhigh_basecase_end:
SIZE(flint_mpn_mulhigh_basecase, .flint_mpn_mulhigh_basecase_end)
.cfi_endproc
_______________________________________________
gmp-devel mailing list
gmp-devel@gmplib.org
https://gmplib.org/mailman/listinfo/gmp-devel

Reply via email to