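Here is a diff that restructures packet handling in snmp.c and snmpc.c so that SNMPv3 support can be added more easily. Message encoding and decoding move out of snmp_resolve() into new snmp_package() and snmp_unpackage() helpers, and the snmpc commands now connect through a single snmpc_connect() function.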
diff --git a/snmp.c b/snmp.c
index 7fac777..b2d5cfc 100644
--- a/snmp.c
+++ b/snmp.c
@@ -34,4 +34,8 @@ static struct ber_element *
 	snmp_resolve(struct snmp_agent *, struct ber_element *, int);
+static char *
+	snmp_package(struct snmp_agent *, struct ber_element *, ssize_t *);
+static struct ber_element *
+	snmp_unpackage(struct snmp_agent *, char *, size_t);
 
 struct snmp_agent *
 snmp_connect_v12(int fd, enum snmp_version version, const char *community)
@@ -171,19 +175,16 @@ fail:
 static struct ber_element *
 snmp_resolve(struct snmp_agent *agent, struct ber_element *pdu, int reply)
 {
-	struct ber_element *message, *varbind;
+	struct ber_element *varbind;
 	struct ber_oid oid;
 	struct timespec start, now;
 	struct pollfd pfd;
-	struct ber ber;
+	char *message;
 	ssize_t len;
 	long long reqid, rreqid;
-	long long version;
-	char *community;
 	short direction;
 	int to, nfds, ret;
 	int tries;
-	void *ptr;
 	char buf[READ_BUF_SIZE];
 
 	if (ber_scanf_elements(pdu, "{i", &reqid) != 0) {
@@ -192,23 +193,8 @@ snmp_resolve(struct snmp_agent *agent, struct ber_element *pdu, int reply)
 		return NULL;
 	}
 
-	if ((message = ber_add_sequence(NULL)) == NULL) {
-		ber_free_elements(pdu);
-		return NULL;
-	}
-	if (ber_printf_elements(message, "dse", agent->version,
-	    agent->community, pdu) == NULL) {
-		ber_free_elements(pdu);
-		ber_free_elements(message);
+	if ((message = snmp_package(agent, pdu, &len)) == NULL)
 		return NULL;
-	}
-	memset(&ber, 0, sizeof(ber));
-	ber_set_application(&ber, smi_application);
-	len = ber_write_elements(&ber, message);
-	ber_free_elements(message);
-	message = NULL;
-	if (ber_get_writebuf(&ber, &ptr) < 1)
-		goto fail;
 
 	clock_gettime(CLOCK_MONOTONIC, &start);
 	memcpy(&now, &start, sizeof(now));
@@ -236,7 +222,7 @@ snmp_resolve(struct snmp_agent *agent, struct ber_element *pdu, int reply)
 			goto fail;
 		}
 		if (direction == POLLOUT) {
-			ret = send(agent->fd, ptr, len, MSG_DONTWAIT);
+			ret = send(agent->fd, message, len, MSG_DONTWAIT);
 			if (ret == -1)
 				goto fail;
 			if (ret < len) {
@@ -253,25 +239,10 @@ snmp_resolve(struct snmp_agent *agent, struct ber_element *pdu, int reply)
 			errno = ECONNRESET;
 		if (ret <= 0)
 			goto fail;
-		ber_set_readbuf(&ber, buf, ret);
-		if ((message = ber_read_elements(&ber, NULL)) == NULL) {
-			direction = POLLOUT;
+		if ((pdu = snmp_unpackage(agent, buf, ret)) == NULL) {
 			tries--;
-			continue;
-		}
-		if (ber_scanf_elements(message, "{ise", &version, &community,
-		    &pdu) != 0) {
-			errno = EPROTO;
 			direction = POLLOUT;
-			tries--;
-			continue;
-		}
-		/* Skip invalid packets; should not happen */
-		if (version != agent->version ||
-		    strcmp(community, agent->community) != 0) {
 			errno = EPROTO;
-			direction = POLLOUT;
-			tries--;
 			continue;
 		}
 		/* Validate pdu format and check request id */
@@ -297,17 +268,120 @@ snmp_resolve(struct snmp_agent *agent, struct ber_element *pdu, int reply)
 				break;
 			}
 		}
-		if (varbind != NULL)
+		if (varbind != NULL) {
+			ber_free_elements(pdu);
 			continue;
-		ber_unlink_elements(message->be_sub->be_next);
-		ber_free_elements(message);
-		ber_free(&ber);
+		}
+		free(message);
 		return pdu;
 	}
 
+fail:
+	free(message);
+	return NULL;
+}
+
+/*
+ * Wrap pdu in a message for agent->version and serialize it.
+ * pdu is consumed in all cases.  On success return a malloc(3)ed
+ * buffer holding the encoded message and store its length in *len;
+ * the caller must free(3) it.  Return NULL on failure.
+ */
+static char *
+snmp_package(struct snmp_agent *agent, struct ber_element *pdu, ssize_t *len)
+{
+	struct ber ber;
+	struct ber_element *message;
+	ssize_t ret;
+	char *buf, *packet = NULL;
+
+	bzero(&ber, sizeof(ber));
+	ber_set_application(&ber, smi_application);
+
+	if ((message = ber_add_sequence(NULL)) == NULL) {
+		ber_free_elements(pdu);
+		return NULL;
+	}
+
+	switch (agent->version) {
+	case SNMP_V1:
+	case SNMP_V2C:
+		if (ber_printf_elements(message, "dse", agent->version,
+		    agent->community, pdu) == NULL) {
+			ber_free_elements(pdu);
+			goto fail;
+		}
+		break;
+	case SNMP_V3:
+	default:
+		/* Not implemented yet; never send an empty message. */
+		ber_free_elements(pdu);
+		errno = EPROTONOSUPPORT;
+		goto fail;
+	}
+
+	if ((ret = ber_write_elements(&ber, message)) == -1)
+		goto fail;
+	*len = ret;
+	if (ber_get_writebuf(&ber, (void **)&buf) != -1 &&
+	    (packet = malloc(ret)) != NULL)
+		memcpy(packet, buf, ret);
+
 fail:
 	ber_free_elements(message);
+	ber_free(&ber);
+	return packet;
+}
+
+/*
+ * Decode the message in buf and return its PDU, unlinked from the
+ * message so the caller owns it.  Return NULL if buf cannot be
+ * decoded, or if its version or community does not match agent.
+ * SNMPv3 messages are not handled yet.
+ */
+static struct ber_element *
+snmp_unpackage(struct snmp_agent *agent, char *buf, size_t buflen)
+{
+	struct ber ber;
+	long long version;
+	char *community;
+	struct ber_element *pdu;
+	struct ber_element *message, *payload;
+
+	bzero(&ber, sizeof(ber));
+	ber_set_application(&ber, smi_application);
+
+	ber_set_readbuf(&ber, buf, buflen);
+	message = ber_read_elements(&ber, NULL);
 	ber_free(&ber);
+	if (message == NULL)
+		return NULL;
+
+	if (ber_scanf_elements(message, "{ie", &version, &payload) != 0)
+		goto fail;
+
+	if (version != agent->version)
+		goto fail;
+
+	switch (version) {
+	case SNMP_V1:
+	case SNMP_V2C:
+		if (ber_scanf_elements(payload, "se", &community, &pdu) != 0)
+			goto fail;
+		/* Skip packets for another community; should not happen */
+		if (strcmp(community, agent->community) != 0)
+			goto fail;
+		ber_unlink_elements(payload);
+		ber_free_elements(message);
+		return pdu;
+	case SNMP_V3:
+	default:
+		/* Not implemented yet. */
+		break;
+	}
+
+fail:
+	ber_free_elements(message);
 	return NULL;
 }
 
diff --git a/snmpc.c b/snmpc.c
index dc48b10..104bb7c 100644
--- a/snmpc.c
+++ b/snmpc.c
@@ -45,6 +45,7 @@ int snmpc_get(int, char *[]);
 int snmpc_walk(int, char *[]);
 int snmpc_trap(int, char *[]);
 int snmpc_mibtree(int, char *[]);
+struct snmp_agent *snmpc_connect(char *, char *);
 int snmpc_parseagent(char *, char *);
 int snmpc_print(struct ber_element *);
 __dead void snmpc_printerror(enum snmp_error, char *);
@@ -304,9 +305,7 @@ snmpc_get(int argc, char *argv[])
 	if (argc < 2)
 		usage();
 
-	agent = snmp_connect_v12(snmpc_parseagent(argv[0], "161"), version,
-	    community);
-	if (agent == NULL)
+	if ((agent = snmpc_connect(argv[0], "161")) == NULL)
 		err(1, "%s", snmp_app->name);
 	agent->timeout = timeout;
 	agent->retries = retries;
@@ -372,8 +371,7 @@ snmpc_walk(int argc, char *argv[])
 		usage();
 	oids = argc == 1 ? mib : argv[1];
 
-	agent = snmp_connect_v12(snmpc_parseagent(argv[0], "161"), version, community);
-	if (agent == NULL)
+	if ((agent = snmpc_connect(argv[0], "161")) == NULL)
 		err(1, "%s", snmp_app->name);
 	agent->timeout = timeout;
 	agent->retries = retries;
@@ -495,9 +493,7 @@ snmpc_trap(int argc, char *argv[])
 	if (version == SNMP_V1)
 		errx(1, "trap is not supported for snmp v1");
 
-	agent = snmp_connect_v12(snmpc_parseagent(argv[0], "162"),
-	    version, community);
-	if (agent == NULL)
+	if ((agent = snmpc_connect(argv[0], "162")) == NULL)
 		err(1, "%s", snmp_app->name);
 
 	if (pledge("stdio", NULL) == -1)
@@ -693,6 +689,21 @@ snmpc_mibtree(int argc, char *argv[])
 	return 0;
 }
 
+struct snmp_agent *
+snmpc_connect(char *host, char *port)
+{
+	switch (version) {
+	case SNMP_V1:
+	case SNMP_V2C:
+		return snmp_connect_v12(snmpc_parseagent(host, port), version,
+		    community);
+	case SNMP_V3:
+		/* Not implemented yet. */
+		break;
+	}
+	return NULL;
+}
+
 int
 snmpc_print(struct ber_element *elm)
 {