On Tuesday 10 June 2008, Sepherosa Ziehau wrote:
> On 6/10/08, Aggelos Economopoulos <[EMAIL PROTECTED]> wrote:
> > On Friday 06 June 2008, Aggelos Economopoulos wrote:
> >  [...]
> >
> > > OK, same thing, but now it's the pcbs. TCP is "easy". The inpcb is
> > > inserted into a per-cpu hash table so that the corresponding protocol
> > > thread runs on the same cpu. Some tcpcb fields, however, are accessed
> > > directly from user-thread context. The interesting fields are:
> > >
> > > t_flags:
> > >       need to do early copyin / delayed copyout in so_pr_ctloutput
> >
> >
> > So, I was thinking something like the following:
> >
> >  diff --git a/sys/kern/uipc_msg.c b/sys/kern/uipc_msg.c
> >  index fde4c93..e786e02 100644
> >  --- a/sys/kern/uipc_msg.c
> >  +++ b/sys/kern/uipc_msg.c
> >  @@ -35,6 +35,7 @@
> >
> >   #include <sys/param.h>
> >   #include <sys/systm.h>
> >  +#include <sys/kernel.h>
> >   #include <sys/msgport.h>
> >   #include <sys/protosw.h>
> >   #include <sys/socket.h>
> >  @@ -379,22 +380,38 @@ so_pru_sopoll(struct socket *so, int events, struct 
> > ucred *cred)
> >         return (error);
> >   }
> >
> >  +MALLOC_DEFINE(M_SOPT, "sopt", "sopt temp storage");
> >  +
> >   int
> >   so_pr_ctloutput(struct socket *so, struct sockopt *sopt)
> >   {
> >  -       return ((*so->so_proto->pr_ctloutput)(so, sopt));
> >   #ifdef gag     /* does copyin and copyout deep inside stack XXX JH */
> >  +       return ((*so->so_proto->pr_ctloutput)(so, sopt));
> >  +#else
> >         struct netmsg_pr_ctloutput msg;
> >         lwkt_port_t port;
> >         int error;
> >  +       void *uval;
> >  +
> >  +       uval = sopt->sopt_val;
> >
> >  -       port = so->so_proto->pr_mport(so, NULL);
> >  +       /* we keep duplicate copies, but for option {s,g}etting who cares? 
> > */
> >  +       sopt->sopt_val = kmalloc(sopt->sopt_valsize, M_SOPT, M_WAITOK);
> >  +       error = copyin(uval, sopt->sopt_val, sopt->sopt_valsize);
> >  +       if (error)
> >  +               goto out;
> >  +       port = so->so_proto->pr_mport(so, NULL, NULL, XXX);
> >         netmsg_init(&msg.nm_netmsg, &curthread->td_msgport, 0,
> >                     netmsg_pru_ctloutput);
> >         msg.nm_prfn = so->so_proto->pr_ctloutput;
> >         msg.nm_so = so;
> >         msg.nm_sopt = sopt;
> >         error = lwkt_domsg(port, &msg.nm_netmsg.nm_lmsg, 0);
> >  +       if (error)
> >  +               goto out;
> >  +       error = copyout(sopt->sopt_val, uval, sopt->sopt_valsize);
> >  +out:
> >  +       kfree(sopt->sopt_val, M_SOPT);
> >         return (error);
> >   #endif
> >   }
> >
> >  But before I update all callees to remove copy{in,out}(), does anybody have
> >  any objections? Also, what should be the value of XXX? Perhaps ctloutput
> >  belongs to pr_usrreqs...
> 
> I don't think the above change is enough.  You need to change
> sooptcopy{in,out}() or at least set sopt->sopt_td to NULL.  Please
> take a look at ip_ctloutput() for what the above patch may break.

Yes, I know. I'm in the middle of updating the callees. This is how I'm doing
it:

diff --git a/sys/netinet/ip_output.c b/sys/netinet/ip_output.c
index 9278ffa..549352f 100644
--- a/sys/netinet/ip_output.c
+++ b/sys/netinet/ip_output.c
@@ -1418,8 +1418,7 @@ ip_ctloutput(struct socket *so, struct sockopt *sopt)
                                break;
                        }
                        m->m_len = sopt->sopt_valsize;
-                       error = sooptcopyin(sopt, mtod(m, char *), m->m_len,
-                                           m->m_len);
+                       bcopy(sopt->sopt_val, mtod(m, void *), m->m_len);
 
                        return (ip_pcbopts(sopt->sopt_name, &inp->inp_options,
                                           m));
@@ -1434,10 +1433,11 @@ ip_ctloutput(struct socket *so, struct sockopt *sopt)
                case IP_RECVIF:
                case IP_RECVTTL:
                case IP_FAITH:
-                       error = sooptcopyin(sopt, &optval, sizeof optval,
-                                           sizeof optval);
-                       if (error)
+                       if (sopt->sopt_valsize != sizeof optval) {
+                               error = EINVAL;
                                break;
+                       }
+                       optval = *(int *)sopt->sopt_val;
 
                        switch (sopt->sopt_name) {
                        case IP_TOS:
@@ -1826,6 +1826,13 @@ ip_setmoptions(struct sockopt *sopt, struct ip_moptions 
**imop)
                imo->imo_num_memberships = 0;
        }
 
+#define getval(var)    \
+       if (sopt->sopt_valsize != sizeof var) { \
+               error = EINVAL;                 \
+               break;                          \
+       }                                       \
+       bcopy(sopt->sopt_val, &var, sizeof var) \
+
        switch (sopt->sopt_name) {
        /* store an index number for the vif you wanna use in the send */
        case IP_MULTICAST_VIF:
@@ -1833,9 +1840,7 @@ ip_setmoptions(struct sockopt *sopt, struct ip_moptions 
**imop)
                        error = EOPNOTSUPP;
                        break;
                }
-               error = sooptcopyin(sopt, &i, sizeof i, sizeof i);
-               if (error)
-                       break;
+               getval(i);
                if (!legal_vif_num(i) && (i != -1)) {
                        error = EINVAL;
                        break;
@@ -1847,9 +1852,7 @@ ip_setmoptions(struct sockopt *sopt, struct ip_moptions 
**imop)
                /*
                 * Select the interface for outgoing multicast packets.
                 */
-               error = sooptcopyin(sopt, &addr, sizeof addr, sizeof addr);
-               if (error)
-                       break;
+               getval(addr);
                /*
                 * INADDR_ANY is used to remove a previous selection.
                 * When no interface is selected, a default one is
@@ -1888,15 +1891,11 @@ ip_setmoptions(struct sockopt *sopt, struct ip_moptions 
**imop)
                 */
                if (sopt->sopt_valsize == 1) {
                        u_char ttl;
-                       error = sooptcopyin(sopt, &ttl, 1, 1);
-                       if (error)
-                               break;
+                       getval(ttl);
                        imo->imo_multicast_ttl = ttl;
                } else {
                        u_int ttl;
-                       error = sooptcopyin(sopt, &ttl, sizeof ttl, sizeof ttl);
-                       if (error)
-                               break;
+                       getval(ttl);
                        if (ttl > 255)
                                error = EINVAL;
                        else
@@ -1914,17 +1913,12 @@ ip_setmoptions(struct sockopt *sopt, struct ip_moptions 
**imop)
                if (sopt->sopt_valsize == 1) {
                        u_char loop;
 
-                       error = sooptcopyin(sopt, &loop, 1, 1);
-                       if (error)
-                               break;
+                       getval(loop);
                        imo->imo_multicast_loop = !!loop;
                } else {
                        u_int loop;
 
-                       error = sooptcopyin(sopt, &loop, sizeof loop,
-                                           sizeof loop);
-                       if (error)
-                               break;
+                       getval(loop);
                        imo->imo_multicast_loop = !!loop;
                }
                break;
@@ -1934,9 +1928,7 @@ ip_setmoptions(struct sockopt *sopt, struct ip_moptions 
**imop)
                 * Add a multicast group membership.
                 * Group must be a valid IP multicast address.
                 */
-               error = sooptcopyin(sopt, &mreq, sizeof mreq, sizeof mreq);
-               if (error)
-                       break;
+               getval(mreq);
 
                if (!IN_MULTICAST(ntohl(mreq.imr_multiaddr.s_addr))) {
                        error = EINVAL;
@@ -2015,10 +2007,8 @@ ip_setmoptions(struct sockopt *sopt, struct ip_moptions 
**imop)
                 * Drop a multicast group membership.
                 * Group must be a valid IP multicast address.
                 */
-               error = sooptcopyin(sopt, &mreq, sizeof mreq, sizeof mreq);
-               if (error)
-                       break;
-
+               getval(mreq);
+#undef getval
                if (!IN_MULTICAST(ntohl(mreq.imr_multiaddr.s_addr))) {
                        error = EINVAL;
                        break;
diff --git a/sys/netinet/tcp_usrreq.c b/sys/netinet/tcp_usrreq.c
index 4434133..4cedd42 100644
--- a/sys/netinet/tcp_usrreq.c
+++ b/sys/netinet/tcp_usrreq.c
@@ -1156,14 +1156,14 @@ tcp_ctloutput(struct socket *so, struct sockopt *sopt)
 
        switch (sopt->sopt_dir) {
        case SOPT_SET:
+               if (sopt->sopt_valsize != sizeof optval) {
+                       error = EINVAL;
+                       break;
+               }
+               optval = *(int *)sopt->sopt_valsize;
                switch (sopt->sopt_name) {
                case TCP_NODELAY:
                case TCP_NOOPT:
-                       error = sooptcopyin(sopt, &optval, sizeof optval,
-                                           sizeof optval);
-                       if (error)
-                               break;
-
                        switch (sopt->sopt_name) {
                        case TCP_NODELAY:
                                opt = TF_NODELAY;
@@ -1183,11 +1183,6 @@ tcp_ctloutput(struct socket *so, struct sockopt *sopt)
                        break;
 
                case TCP_NOPUSH:
-                       error = sooptcopyin(sopt, &optval, sizeof optval,
-                                           sizeof optval);
-                       if (error)
-                               break;
-
                        if (optval)
                                tp->t_flags |= TF_NOPUSH;
                        else {
@@ -1197,11 +1192,6 @@ tcp_ctloutput(struct socket *so, struct sockopt *sopt)
                        break;
 
                case TCP_MAXSEG:
-                       error = sooptcopyin(sopt, &optval, sizeof optval,
-                                           sizeof optval);
-                       if (error)
-                               break;
-
                        if (optval > 0 && optval <= tp->t_maxseg)
                                tp->t_maxseg = optval;
                        else
@@ -1232,8 +1222,11 @@ tcp_ctloutput(struct socket *so, struct sockopt *sopt)
                        error = ENOPROTOOPT;
                        break;
                }
-               if (error == 0)
-                       error = sooptcopyout(sopt, &optval, sizeof optval);
+               if (error == 0) {
+                       sopt->sopt_valsize = min(sizeof optval,
+                                                sopt->sopt_valsize);
+                       bcopy(&optval, sopt->sopt_val, sopt->sopt_valsize);
+               }
                break;
        }
        crit_exit();

Does that look OK to you?

Thanks,
Aggelos

Reply via email to