Hi everyone, I've put together a patch for 6.0-stable that adds domain name matching support to rebound(8). The patch is quite rough at the moment.
The config is as follows:

    match "local." 10.0.0.53
    match "." 8.8.8.8

Requests for foo.local. are forwarded to 10.0.0.53; all other requests go to 8.8.8.8. In my implementation the first matching rule wins, so more specific patterns must be listed before the catch-all ".".

General drawbacks:
- rebound has to parse DNS requests. I tried to keep the parsing code as small as possible to limit the attack surface.

Drawbacks of the current implementation:
- No caching for DNS requests over TCP. I plan to implement this via a unified cache that works for both UDP and TCP.
- No non-blocking connect(2) support for TCP. The original code handled this, but I replaced it with a blocking connect(2) because I wanted to get the feature working first.

What do you think?

=================================================================== RCS file: /cvs/src/usr.sbin/rebound/rebound.c,v retrieving revision 1.65 diff -u -p -r1.65 rebound.c --- rebound.c 2 Jul 2016 17:09:09 -0000 1.65 +++ rebound.c 16 Sep 2016 12:29:39 -0000 @@ -37,6 +37,8 @@ #include <getopt.h> #include <stdarg.h> +#define LEN(x) (sizeof (x) / sizeof *(x)) + uint16_t randomid(void); static struct timespec now; @@ -100,6 +102,13 @@ struct request { }; static TAILQ_HEAD(, request) reqfifo; +struct match { + char pat[256]; + struct sockaddr_storage to; + TAILQ_ENTRY(match) entry; +}; +static TAILQ_HEAD(, match) matches; + static int conncount; static int connmax; static uint64_t conntotal; @@ -215,10 +224,94 @@ servfail(int ud, uint16_t id, struct soc sendto(ud, &pkt, sizeof(pkt), 0, fromaddr, fromlen); } +static size_t +readn(int fd, void *buf, size_t n) +{ + size_t total = 0; + size_t r; + + while (n > 0) { + r = read(fd, buf + total, n); + if (r == 0 || r == -1) + return -1; + total += r; + n -= r; + } + return total; +} + +static size_t +writen(int fd, void *buf, size_t n) +{ + size_t total = 0; + size_t r; + + while (n > 0) { + r = write(fd, buf + total, n); + if (r == 0 || r == -1) + return -1; + total += r; + n -= r; + } + return total; +} + +int +parsedomain(uint8_t *buf, size_t buflen, char *host, size_t hostlen) +{ + 
uint8_t *bp = &buf[0], *be = &buf[buflen]; + char *hp = &host[0], *he = &host[hostlen]; + + bp += sizeof(struct dnspacket); + if (bp >= be) + return -1; + for (;;) { + uint8_t len = *bp++; + if (len == 0) + break; + if (bp + len >= be || hp + len >= he) + return -1; + memcpy(hp, bp, len); + bp += len; + hp += len; + *hp++ = '.'; + if (hp == he) + return -1; + } + *hp = '\0'; + return 0; +} + +int +matchreq(uint8_t *buf, size_t buflen, struct sockaddr_storage *to) +{ + char host[65536]; + struct match *match; + + /* XXX: check flags/qdcount? */ + if (parsedomain(buf, buflen, host, sizeof(host)) == -1) + return -1; + TAILQ_FOREACH(match, &matches, entry) { + size_t hlen = strlen(host); + size_t glen = strlen(match->pat); + if (hlen < glen) + continue; + if (strcmp(&host[hlen - glen], match->pat) == 0) { + memcpy(to, &match->to, sizeof(*to)); + logmsg(LOG_DEBUG, "matched domain %s with %s", + host, match->pat); + /* first match wins */ + return 0; + } + } + return -1; +} + static struct request * -newrequest(int ud, struct sockaddr *remoteaddr) +newrequest(int ud) { - struct sockaddr from; + struct sockaddr_storage remoteaddr; + struct sockaddr from, *to; socklen_t fromlen; struct request *req; uint8_t buf[65536]; @@ -271,13 +364,17 @@ newrequest(int ud, struct sockaddr *remo } req->cacheent = hit; - req->s = socket(remoteaddr->sa_family, SOCK_DGRAM, 0); + if (matchreq(buf, r, &remoteaddr) == -1) + goto fail; + to = (struct sockaddr *)&remoteaddr; + + req->s = socket(to->sa_family, SOCK_DGRAM, 0); if (req->s == -1) goto fail; TAILQ_INSERT_TAIL(&reqfifo, req, fifo); - if (connect(req->s, remoteaddr, remoteaddr->sa_len) == -1) { + if (connect(req->s, to, to->sa_len) == -1) { logmsg(LOG_NOTICE, "failed to connect (%d)", errno); if (errno == EADDRNOTAVAIL) servfail(ud, req->clientid, &from, fromlen); @@ -335,36 +432,18 @@ sendreply(int ud, struct request *req) } static struct request * -tcpphasetwo(struct request *req) -{ - int error; - socklen_t len = sizeof(error); - - 
req->tcp = 2; - - if (getsockopt(req->s, SOL_SOCKET, SO_ERROR, &error, &len) == -1 || - error != 0) - goto fail; - if (setsockopt(req->client, SOL_SOCKET, SO_SPLICE, &req->s, - sizeof(req->s)) == -1) - goto fail; - if (setsockopt(req->s, SOL_SOCKET, SO_SPLICE, &req->client, - sizeof(req->client)) == -1) - goto fail; - - return req; - -fail: - freerequest(req); - return NULL; -} - -static struct request * -newtcprequest(int ld, struct sockaddr *remoteaddr) +newtcprequest(int ld) { + struct sockaddr_storage remoteaddr; + struct sockaddr *to; struct request *req; + uint8_t buf[65536]; + struct dnspacket *dnsreq; + uint16_t reqsize; int client; + dnsreq = (struct dnspacket *)&buf[2]; + client = accept(ld, NULL, 0); if (client == -1) { if (errno == ENFILE || errno == EMFILE) @@ -372,6 +451,24 @@ newtcprequest(int ld, struct sockaddr *r return NULL; } + if (readn(client, &reqsize, sizeof(reqsize)) == -1) { + close(client); + return NULL; + } + if (reqsize > sizeof(buf) - 2) { + close(client); + return NULL; + } + memcpy(buf, &reqsize, sizeof(reqsize)); + + reqsize = ntohs(reqsize); + if (readn(client, &buf[2], reqsize) == -1) { + close(client); + return NULL; + } + + /* XXX: unified cache handling for tcp/udp requests */ + if (!(req = calloc(1, sizeof(*req)))) { close(client); return NULL; @@ -383,18 +480,31 @@ newtcprequest(int ld, struct sockaddr *r req->ts.tv_sec += 30; req->tcp = 1; req->client = client; + req->s = -1; + + req->clientid = dnsreq->id; + req->reqid = randomid(); + dnsreq->id = req->reqid; - req->s = socket(remoteaddr->sa_family, SOCK_STREAM | SOCK_NONBLOCK, 0); + if (matchreq(&buf[2], reqsize, &remoteaddr) == -1) + goto fail; + to = (struct sockaddr *)&remoteaddr; + + req->s = socket(to->sa_family, SOCK_STREAM, 0); if (req->s == -1) goto fail; TAILQ_INSERT_TAIL(&reqfifo, req, fifo); - if (connect(req->s, remoteaddr, remoteaddr->sa_len) == -1) { - if (errno != EINPROGRESS) - goto fail; - } else { - return tcpphasetwo(req); + /* XXX: should really use 
non-blocking connect */ + if (connect(req->s, to, to->sa_len) == -1) { + logmsg(LOG_NOTICE, "failed to connect (%d)", errno); + goto fail; + } + + if (writen(req->s, buf, reqsize + 2) == -1) { + logmsg(LOG_NOTICE, "failed to write (%d)", errno); + goto fail; } return req; @@ -404,43 +514,133 @@ fail: return NULL; } +static void +sendtcpreply(struct request *req) +{ + uint8_t buf[65536]; + struct dnspacket *resp; + uint16_t reqsize; + + resp = (struct dnspacket *)&buf[2]; + + if (readn(req->s, &reqsize, sizeof(reqsize)) == -1) + return; + if (reqsize > sizeof(buf) - 2) + return; + memcpy(buf, &reqsize, sizeof(reqsize)); + + reqsize = ntohs(reqsize); + if (readn(req->s, &buf[2], reqsize) == -1) + return; + if (resp->id != req->reqid) + return; + resp->id = req->clientid; + + if (writen(req->client, buf, reqsize + 2) == -1) + return; + + /* XXX: cache handling */ +} + +static void +free_matches(void) +{ + struct match *match, *tmp; + + for (match = TAILQ_FIRST(&matches); match != NULL; match = tmp) { + tmp = TAILQ_NEXT(match, entry); + TAILQ_REMOVE(&matches, match, entry); + free(match); + } +} + static int -readconfig(FILE *conf, struct sockaddr_storage *remoteaddr) +readconfig(FILE *conf) { +#define KEYWORDIDX 0 +#define PATTERNIDX 1 +#define NSIDX 2 +#define NTOKENS 3 + char *tokens[NTOKENS], *p, *last; char buf[1024]; - struct sockaddr_in *sin = (struct sockaddr_in *)remoteaddr; - struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)remoteaddr; + struct sockaddr_in *sin; + struct sockaddr_in6 *sin6; + struct match *match; + int i; + + free_matches(); /* for SIGHUP */ + while (fgets(buf, sizeof(buf), conf) != NULL) { + buf[strcspn(buf, "\n")] = '\0'; + + /* tokenize line */ + for (i = 0, p = strtok_r(buf, " \t", &last); + p != NULL; + p = strtok_r(NULL, " \t", &last)) + if (i < LEN(tokens)) + tokens[i++] = p; + if (i != NTOKENS) + goto fail; - if (fgets(buf, sizeof(buf), conf) == NULL) - return -1; - buf[strcspn(buf, "\n")] = '\0'; + /* only recognize the match 
keyword so far */ + if (strcmp(tokens[KEYWORDIDX], "match") != 0) + goto fail; - memset(remoteaddr, 0, sizeof(*remoteaddr)); - if (inet_pton(AF_INET, buf, &sin->sin_addr) == 1) { - sin->sin_len = sizeof(*sin); - sin->sin_family = AF_INET; - sin->sin_port = htons(53); - return AF_INET; - } else if (inet_pton(AF_INET6, buf, &sin6->sin6_addr) == 1) { - sin6->sin6_len = sizeof(*sin6); - sin6->sin6_family = AF_INET6; - sin6->sin6_port = htons(53); - return AF_INET6; - } else { - return -1; + match = malloc(sizeof(*match)); + if (match == NULL) + goto fail; + + /* extract pattern */ + for (i = 0, p = tokens[PATTERNIDX]; *p != '\0'; p++) { + if (*p == '"') + continue; + if (i < LEN(match->pat) - 1) + match->pat[i++] = *p; + } + if (i == 0) { + /* empty pattern? bail */ + free(match); + goto fail; + } + match->pat[i] = '\0'; + + memset(&match->to, 0, sizeof(match->to)); + sin = (struct sockaddr_in *)&match->to; + sin6 = (struct sockaddr_in6 *)&match->to; + if (inet_pton(AF_INET, tokens[NSIDX], &sin->sin_addr) == 1) { + sin->sin_len = sizeof(*sin); + sin->sin_family = AF_INET; + sin->sin_port = htons(53); + } else if (inet_pton(AF_INET6, tokens[NSIDX], &sin6->sin6_addr) == 1) { + sin6->sin6_len = sizeof(*sin6); + sin6->sin6_family = AF_INET6; + sin6->sin6_port = htons(53); + } else { + free(match); + goto fail; + } + + TAILQ_INSERT_TAIL(&matches, match, entry); } + + /* we need at least one match rule */ + if (TAILQ_EMPTY(&matches)) + goto fail; + + return 0; +fail: + free_matches(); + return -1; } static int launch(FILE *conf, int ud, int ld, int kq) { - struct sockaddr_storage remoteaddr; struct kevent ch[2], kev[4]; struct timespec ts, *timeout = NULL; struct request *req; struct dnscache *ent; struct passwd *pwd; - int i, r, af; + int i, r; pid_t parent, child; parent = getpid(); @@ -476,9 +676,9 @@ launch(FILE *conf, int ud, int ld, int k if (pledge("stdio inet", NULL) == -1) logerr("pledge failed"); - af = readconfig(conf, &remoteaddr); + r = readconfig(conf); 
fclose(conf); - if (af == -1) + if (r == -1) logerr("parse error in config file"); EV_SET(&kev[0], ud, EVFILT_READ, EV_ADD, 0, 0, NULL); @@ -517,37 +717,26 @@ launch(FILE *conf, int ud, int ld, int k } else if (kev[i].filter == EVFILT_PROC) { logmsg(LOG_INFO, "parent died"); exit(0); - } else if (kev[i].filter == EVFILT_WRITE) { - req = kev[i].udata; - req = tcpphasetwo(req); - if (req) { - EV_SET(&ch[0], req->s, EVFILT_WRITE, - EV_DELETE, 0, 0, NULL); - EV_SET(&ch[1], req->s, EVFILT_READ, - EV_ADD, 0, 0, req); - kevent(kq, ch, 2, NULL, 0, NULL); - } } else if (kev[i].filter != EVFILT_READ) { logerr("don't know what happened"); } else if (kev[i].ident == ud) { - if ((req = newrequest(ud, - (struct sockaddr *)&remoteaddr))) { + if ((req = newrequest(ud))) { EV_SET(&ch[0], req->s, EVFILT_READ, EV_ADD, 0, 0, req); kevent(kq, ch, 1, NULL, 0, NULL); } } else if (kev[i].ident == ld) { - if ((req = newtcprequest(ld, - (struct sockaddr *)&remoteaddr))) { - EV_SET(&ch[0], req->s, - req->tcp == 1 ? EVFILT_WRITE : - EVFILT_READ, EV_ADD, 0, 0, req); + if ((req = newtcprequest(ld))) { + EV_SET(&ch[0], req->s, EVFILT_READ, + EV_ADD, 0, 0, req); kevent(kq, ch, 1, NULL, 0, NULL); } } else { req = kev[i].udata; if (req->tcp == 0) sendreply(ud, req); + else + sendtcpreply(req); freerequest(req); } } @@ -655,6 +844,7 @@ main(int argc, char **argv) TAILQ_INIT(&reqfifo); TAILQ_INIT(&cachefifo); + TAILQ_INIT(&matches); RB_INIT(&cachetree); memset(&bindaddr, 0, sizeof(bindaddr));