#include <errno.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/ioctl.h>
#include <fcntl.h>
#include <unistd.h>
#include <dev/usb/usb.h>
#include <paths.h>

/* bInterfaceClass / bInterfaceSubClass / bInterfaceProtocol of an ADB interface. */
#define ADB_USB_CLASS		0xff
#define ADB_USB_SUBCLASS	0x42
#define ADB_USB_PROTOCOL	0x01

/* Size of the buffer handed to each read(2) on the bulk IN pipe. */
#define ADB_PACKET_MAX		4096

/*
 * A 32-bit word that the name marks as little-endian (ADB's wire order).
 * This is only a naming aid; no byte swapping is performed.  The name
 * deliberately avoids a leading double underscore, which C reserves for the
 * implementation.
 */
typedef uint32_t adb_le32_t;

/*
 * ADB message header: six 32-bit words (24 bytes).  In this program it is
 * used for its size: an endpoint whose wMaxPacketSize cannot hold a whole
 * header is rejected.
 * NOTE(review): the fields presumably correspond to the ADB protocol's
 * command / arg0 / arg1 / data_length / data_check / magic words -- confirm.
 */
struct adb_header {
	adb_le32_t op;
	adb_le32_t arg0;
	adb_le32_t arg1;
	adb_le32_t len;
	adb_le32_t sum;
	adb_le32_t nop;
};

/*
 * Return the message text for the current value of errno.  Intended to be
 * called right after a failing call, before anything else can overwrite
 * errno.
 */
static const char* strerrno(void)
{
	const int err = errno;
	const char* const text = strerror(err);

	return text;
}

/*
 * Open the ugen node for endpoint `ep` of device `xname`, i.e. the path
 * _PATH_DEV "<xname>.<ep>" with the endpoint number printed as at least two
 * zero-padded digits.
 *
 * `flags` is passed to open(2) with O_CLOEXEC always added.
 *
 * Returns the new file descriptor, or -1 with errno set on failure.  A path
 * that does not fit in the buffer fails with EINVAL rather than opening a
 * truncated name.
 */
static int open_endpoint(const char* const xname, unsigned ep, int flags)
{
	/*
	 * Room for _PATH_DEV, a device name of up to USB_MAX_DEVNAMELEN bytes,
	 * and the ".NN" endpoint suffix; without the suffix term a name of the
	 * maximum length would be rejected as too long.
	 */
	char path[sizeof(_PATH_DEV) + USB_MAX_DEVNAMELEN + sizeof(".00")];

	const int ret = snprintf(path, sizeof(path), _PATH_DEV "%s.%02u", xname, ep);
	if (ret < 0)
		return -1;	/* snprintf(3) sets errno on failure */

	if ((size_t)ret >= sizeof(path)) {
		errno = EINVAL;	/* refuse to open a truncated path */
		return -1;
	}

	return open(path, flags | O_CLOEXEC);
}

/*
 * Validate the endpoints of the candidate ADB interface `idesc`, querying
 * them through the ugen control descriptor `fd`, and report the endpoint
 * addresses of its bulk OUT pipe in *outep and its bulk IN pipe in *inep.
 *
 * Endpoints are examined from the highest index down.  Every endpoint must
 * be a bulk endpoint whose wMaxPacketSize is a power of two no smaller than
 * struct adb_header, and at most one endpoint per direction is accepted.
 *
 * Returns only when exactly one OUT and one IN pipe were found.  In every
 * other case a diagnostic naming the device (`udi`, `cdesc`, `idesc`) is
 * printed and the process exits with EXIT_FAILURE; *outep / *inep may have
 * been written by then.
 */
static void find_adb_endpoints(
	int fd, const char* const xname,
	const struct usb_device_info* const udi,
	const usb_config_descriptor_t* const cdesc,
	const usb_interface_descriptor_t* const idesc,
	unsigned* outep, unsigned* inep)
{
	bool seen_out = false;
	bool seen_in = false;
	int idx = idesc->bNumEndpoints;

	while (idx > 0) {
		struct usb_endpoint_desc ued;

		idx--;
		ued.ued_config_index = USB_CURRENT_CONFIG_INDEX;
		ued.ued_interface_index = idesc->bInterfaceNumber;
		ued.ued_alt_index = idesc->bAlternateSetting;
		ued.ued_endpoint_index = idx;

		if (ioctl(fd, USB_GET_ENDPOINT_DESC, &ued, sizeof(ued)) != 0) {
			fprintf(stderr, "%s: %s: %s\n",
				xname, "USB_GET_ENDPOINT_DESC", strerrno());
			goto bad;
		}

		const unsigned addr = ued.ued_desc.bEndpointAddress;
		const unsigned maxpkt = UGETW(ued.ued_desc.wMaxPacketSize);

		/* Bulk only, packet size a power of two that fits a header. */
		if (UE_GET_XFERTYPE(ued.ued_desc.bmAttributes) != UE_BULK)
			goto bad;
		if (maxpkt < sizeof(struct adb_header))
			goto bad;
		if ((maxpkt & (maxpkt - 1)) != 0)
			goto bad;

		if (UE_GET_DIR(addr) == UE_DIR_OUT) {
			if (seen_out)
				goto bad;	/* second OUT pipe */
			seen_out = true;
			*outep = UE_GET_ADDR(addr);
		} else {
			if (seen_in)
				goto bad;	/* second IN pipe */
			seen_in = true;
			*inep = UE_GET_ADDR(addr);
		}
	}

	if (seen_out && seen_in)
		return;

bad:
	fprintf(stderr, "%s [%x/%x/%d/%d]: %s\n",
		xname,
		udi->udi_vendorNo,
		udi->udi_productNo,
		cdesc->bConfigurationValue,
		idesc->bInterfaceNumber,
		"bad ADB interface");

	exit(EXIT_FAILURE);
}

/*
 * adb-ugen-panic: find the ADB interface of the ugen device named on the
 * command line (e.g. "ugen0"), open its bulk IN pipe and sit in a blocking
 * read(2) loop on it, reporting the size of every completed read.
 *
 * The user is prompted to send a signal while the process is blocked in
 * read(2).  The program never exits successfully on its own: only an error
 * ends the loop.
 */
int main(int argc, char* argv[])
{
	if (argc != 2) {
		/* Usage errors are diagnostics: report them on stderr. */
		fprintf(stderr, "usage: adb-ugen-panic ugenX\n");
		return EXIT_FAILURE;
	}

	const char* const xname = argv[1];

	/* Endpoint 0's node is used for the descriptor ioctls below. */
	int fd = open_endpoint(xname, 0, O_RDONLY);
	if (fd < 0) {
		/* 0u: %02u expects an unsigned argument. */
		fprintf(stderr, "%s.%02u: %s\n", xname, 0u, strerrno());
		return EXIT_FAILURE;
	}

	struct usb_device_info udi;

	if (ioctl(fd, USB_GET_DEVICEINFO, &udi, sizeof(udi)) != 0) {
		fprintf(stderr, "%s: %s: %s\n",
			xname, "USB_GET_DEVICEINFO", strerrno());
		return EXIT_FAILURE;
	}

	struct usb_config_desc ucd;
	ucd.ucd_config_index = USB_CURRENT_CONFIG_INDEX;

	if (ioctl(fd, USB_GET_CONFIG_DESC, &ucd, sizeof(ucd)) != 0) {
		fprintf(stderr, "%s: %s: %s\n",
			xname, "USB_GET_CONFIG_DESC", strerrno());
		return EXIT_FAILURE;
	}

	/*
	 * Both pipes are located (find_adb_endpoints validates the whole
	 * interface), but only the IN pipe is opened below.
	 */
	unsigned outep = 0, inep = 0;

	/* Scan the interfaces of the current configuration, highest index first. */
	for (int i = ucd.ucd_desc.bNumInterface;;) {
		if (i == 0) {
			fprintf(stderr, "%s: %s\n",
				xname, "no ADB interface");
			return EXIT_FAILURE;
		}

		/*
		 * Rebuild the whole request on every iteration: the same struct
		 * carries the indices into the ioctl and the descriptor back out,
		 * so don't rely on the previous call leaving the indices intact.
		 */
		struct usb_interface_desc uid;
		uid.uid_config_index = USB_CURRENT_CONFIG_INDEX;
		uid.uid_interface_index = --i;
		uid.uid_alt_index = USB_CURRENT_ALT_INDEX;

		/* Interfaces whose descriptor cannot be read are skipped. */
		if (ioctl(fd, USB_GET_INTERFACE_DESC, &uid, sizeof(uid)) != 0)
			continue;

		if (uid.uid_desc.bInterfaceClass != ADB_USB_CLASS ||
			uid.uid_desc.bInterfaceSubClass != ADB_USB_SUBCLASS ||
			uid.uid_desc.bInterfaceProtocol != ADB_USB_PROTOCOL) {

			continue;
		}

		/* Exits the process unless both bulk pipes were found. */
		find_adb_endpoints(
			fd, xname, &udi, &ucd.ucd_desc, &uid.uid_desc,
			&outep, &inep);

		break;
	}

	close(fd);

	int inpipefd = open_endpoint(xname, inep, O_RDONLY);
	if (inpipefd < 0) {
		fprintf(stderr, "%s.%02u: %s\n", xname, inep, strerrno());
		return EXIT_FAILURE;
	}

	/*
	 * Flush explicitly: when stdout is a pipe or file it is fully buffered,
	 * and the signal the user is asked to send may terminate the process
	 * before buffered progress messages would otherwise be written.
	 */
	printf("found ADB interface, send a signal to trigger panic\n");
	fflush(stdout);

	unsigned char buf[ADB_PACKET_MAX];

	/* Never terminates normally; a read(2) error is the only exit. */
	for (;;) {
		ssize_t ret = read(inpipefd, buf, sizeof(buf));
		if (ret < 0) {
			fprintf(stderr, "%s.%02u: %s\n", xname, inep, strerrno());
			return EXIT_FAILURE;
		}

		printf("got %zd bytes, retrying\n", ret);
		fflush(stdout);
	}
}
