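Here's a diff that restructures packet handling to make adding SNMPv3
easier: message encoding and decoding move out of snmp_resolve() into
new snmp_package() and snmp_unpackage() helpers, and the snmpc.c commands
now connect through a new snmpc_connect() wrapper.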

diff --git a/snmp.c b/snmp.c
index 7fac777..b2d5cfc 100644
--- a/snmp.c
+++ b/snmp.c
@@ -32,6 +32,10 @@
 
 static struct ber_element *
     snmp_resolve(struct snmp_agent *, struct ber_element *, int);
+static char *
+    snmp_package(struct snmp_agent *, struct ber_element *, size_t *);
+static struct ber_element *
+    snmp_unpackage(struct snmp_agent *, char *, size_t);
 
 struct snmp_agent *
 snmp_connect_v12(int fd, enum snmp_version version, const char *community)
@@ -171,19 +175,16 @@ fail:
 static struct ber_element *
 snmp_resolve(struct snmp_agent *agent, struct ber_element *pdu, int reply)
 {
-       struct ber_element *message, *varbind;
+       struct ber_element *varbind;
        struct ber_oid oid;
        struct timespec start, now;
        struct pollfd pfd;
-       struct ber ber;
+       char *message;
        ssize_t len;
        long long reqid, rreqid;
-       long long version;
-       char *community;
        short direction;
        int to, nfds, ret;
        int tries;
-       void *ptr;
        char buf[READ_BUF_SIZE];
 
        if (ber_scanf_elements(pdu, "{i", &reqid) != 0) {
@@ -192,23 +193,8 @@ snmp_resolve(struct snmp_agent *agent, struct ber_element 
*pdu, int reply)
                return NULL;
        }
 
-       if ((message = ber_add_sequence(NULL)) == NULL) {
-               ber_free_elements(pdu);
-               return NULL;
-       }
-       if (ber_printf_elements(message, "dse", agent->version,
-           agent->community, pdu) == NULL) {
-               ber_free_elements(pdu);
-               ber_free_elements(message);
+       if ((message = snmp_package(agent, pdu, &len)) == NULL)
                return NULL;
-       }
-       memset(&ber, 0, sizeof(ber));
-       ber_set_application(&ber, smi_application);
-       len = ber_write_elements(&ber, message);
-       ber_free_elements(message);
-       message = NULL;
-       if (ber_get_writebuf(&ber, &ptr) < 1)
-               goto fail;
 
        clock_gettime(CLOCK_MONOTONIC, &start);
        memcpy(&now, &start, sizeof(now));
@@ -236,7 +222,7 @@ snmp_resolve(struct snmp_agent *agent, struct ber_element 
*pdu, int reply)
                                goto fail;
                }
                if (direction == POLLOUT) {
-                       ret = send(agent->fd, ptr, len, MSG_DONTWAIT);
+                       ret = send(agent->fd, message, len, MSG_DONTWAIT);
                        if (ret == -1)
                                goto fail;
                        if (ret < len) {
@@ -253,25 +239,10 @@ snmp_resolve(struct snmp_agent *agent, struct ber_element 
*pdu, int reply)
                        errno = ECONNRESET;
                if (ret <= 0)
                        goto fail;
-               ber_set_readbuf(&ber, buf, ret);
-               if ((message = ber_read_elements(&ber, NULL)) == NULL) {
-                       direction = POLLOUT;
+               if ((pdu = snmp_unpackage(agent, buf, ret)) == NULL) {
                        tries--;
-                       continue;
-               }
-               if (ber_scanf_elements(message, "{ise", &version, &community,
-                   &pdu) != 0) {
-                       errno = EPROTO;
                        direction = POLLOUT;
-                       tries--;
-                       continue;
-               }
-               /* Skip invalid packets; should not happen */
-               if (version != agent->version ||
-                   strcmp(community, agent->community) != 0) {
                        errno = EPROTO;
-                       direction = POLLOUT;
-                       tries--;
                        continue;
                }
                /* Validate pdu format and check request id */
@@ -297,17 +268,114 @@ snmp_resolve(struct snmp_agent *agent, struct ber_element *pdu, int reply)
                                break;
                        }
                }
                if (varbind != NULL)
                        continue;
 
-               ber_unlink_elements(message->be_sub->be_next);
-               ber_free_elements(message);
-               ber_free(&ber);
+               free(message);
                return pdu;
        }
 
+fail:
+       free(message);
+       return NULL;
+}
+
+/*
+ * Encode pdu into a complete message for agent.  pdu is consumed on every
+ * path.  Returns a malloc'd packet and stores its length in *len, or NULL
+ * on failure.
+ */
+static char *
+snmp_package(struct snmp_agent *agent, struct ber_element *pdu, size_t *len)
+{
+       struct ber ber;
+       struct ber_element *message;
+       ssize_t ret;
+       char *buf, *packet = NULL;
+
+       bzero(&ber, sizeof(ber));
+       ber_set_application(&ber, smi_application);
+
+       if ((message = ber_add_sequence(NULL)) == NULL) {
+               ber_free_elements(pdu);
+               goto fail;
+       }
+
+       switch (agent->version) {
+       case SNMP_V1:
+       case SNMP_V2C:
+               if (ber_printf_elements(message, "dse", agent->version,
+                   agent->community, pdu) == NULL) {
+                       ber_free_elements(pdu);
+                       goto fail;
+               }
+               break;
+       case SNMP_V3:
+               /* Not implemented yet; pdu must still be consumed. */
+               ber_free_elements(pdu);
+               goto fail;
+       }
+
+       if ((ret = ber_write_elements(&ber, message)) == -1)
+               goto fail;
+       *len = (size_t) ret;
+       if (ber_get_writebuf(&ber, (void **)&buf) != -1 &&
+           (packet = malloc(ret)) != NULL)
+               memcpy(packet, buf, ret);
+
 fail:
        ber_free_elements(message);
+       /* Also reached on success; ber was zeroed before any goto. */
+       ber_free(&ber);
+       return packet;
+}
+
+/*
+ * Decode the message in buf and return its pdu, detached from the message
+ * so the caller owns it.  Returns NULL if buf does not parse, if its
+ * version or community differ from agent's, or for SNMPv3 (not yet done).
+ */
+static struct ber_element *
+snmp_unpackage(struct snmp_agent *agent, char *buf, size_t buflen)
+{
+       struct ber ber;
+       enum snmp_version version;
+       char *community;
+       struct ber_element *pdu;
+       struct ber_element *message = NULL, *payload;
+
+       bzero(&ber, sizeof(ber));
+       ber_set_application(&ber, smi_application);
+
+       ber_set_readbuf(&ber, buf, buflen);
+       if ((message = ber_read_elements(&ber, NULL)) == NULL)
+               return NULL;
        ber_free(&ber);
+
+       if (ber_scanf_elements(message, "{de", &version, &payload) != 0)
+               goto fail;
+
+       if (version != agent->version)
+               goto fail;
+
+       switch (version) {
+       case SNMP_V1:
+       case SNMP_V2C:
+               if (ber_scanf_elements(payload, "se", &community, &pdu) != 0)
+                       goto fail;
+               /* Skip replies carrying a different community. */
+               if (strcmp(community, agent->community) != 0)
+                       goto fail;
+               /* Detach pdu from message so it survives the free. */
+               ber_unlink_elements(payload);
+               ber_free_elements(message);
+               return pdu;
+       case SNMP_V3:
+               /* Not implemented yet. */
+               break;
+       }
+
+fail:
+       ber_free_elements(message);
        return NULL;
 }
diff --git a/snmpc.c b/snmpc.c
index dc48b10..104bb7c 100644
--- a/snmpc.c
+++ b/snmpc.c
@@ -45,6 +45,7 @@ int snmpc_get(int, char *[]);
 int snmpc_walk(int, char *[]);
 int snmpc_trap(int, char *[]);
 int snmpc_mibtree(int, char *[]);
+struct snmp_agent *snmpc_connect(char *, char *);
 int snmpc_parseagent(char *, char *);
 int snmpc_print(struct ber_element *);
 __dead void snmpc_printerror(enum snmp_error, char *);
@@ -304,9 +305,7 @@ snmpc_get(int argc, char *argv[])
        if (argc < 2)
                usage();
 
-       agent = snmp_connect_v12(snmpc_parseagent(argv[0], "161"), version,
-           community);
-       if (agent == NULL)
+       if ((agent = snmpc_connect(argv[0], "161")) == NULL)
                err(1, "%s", snmp_app->name);
        agent->timeout = timeout;
        agent->retries = retries;
@@ -372,8 +371,7 @@ snmpc_walk(int argc, char *argv[])
                usage();
        oids = argc == 1 ? mib : argv[1];
 
-       agent = snmp_connect_v12(snmpc_parseagent(argv[0], "161"), version, community);
-       if (agent == NULL)
+       if ((agent = snmpc_connect(argv[0], "161")) == NULL)
                err(1, "%s", snmp_app->name);
        agent->timeout = timeout;
        agent->retries = retries;
@@ -495,9 +493,7 @@ snmpc_trap(int argc, char *argv[])
        if (version == SNMP_V1)
                errx(1, "trap is not supported for snmp v1");
 
-       agent = snmp_connect_v12(snmpc_parseagent(argv[0], "162"),
-           version, community);
-       if (agent == NULL)
+       if ((agent = snmpc_connect(argv[0], "162")) == NULL)
                err(1, "%s", snmp_app->name);
 
        if (pledge("stdio", NULL) == -1)
@@ -693,6 +689,26 @@ snmpc_mibtree(int argc, char *argv[])
        return 0;
 }
 
+/*
+ * Connect to the agent named by host and port using the global version
+ * and community settings.  Returns NULL on failure or for a version that
+ * has no connect routine yet.
+ */
+struct snmp_agent *
+snmpc_connect(char *host, char *port)
+{
+       switch (version) {
+       case SNMP_V1:
+       case SNMP_V2C:
+               return snmp_connect_v12(snmpc_parseagent(host, port), version,
+                   community);
+       case SNMP_V3:
+               /* Not implemented yet. */
+               break;
+       }
+       return NULL;
+}
+
 int
 snmpc_print(struct ber_element *elm)
 {

Reply via email to