Fix dc_queuecmd portability (and changed syntax!).
[doldaconnect.git] / lib / uilib.c
index 2bd9575..031fc16 100644 (file)
 #include <sys/socket.h>
 #include <netinet/in.h>
 #include <arpa/inet.h>
+#include <sys/un.h>
 #include <fcntl.h>
 #include <netdb.h>
 #include <sys/poll.h>
+#include <pwd.h>
 #ifdef HAVE_RESOLVER
 #include <arpa/nameser.h>
 #include <resolv.h>
 #endif
 
 #include <doldaconnect/uilib.h>
-#include <doldaconnect/utils.h>
+#include <utils.h>
+
+#define DOLCON_SRV_NAME "_dolcon._tcp"
 
 #define RESP_END -1
 #define RESP_DSC 0
@@ -92,9 +96,12 @@ static int state = -1;
 static int fd = -1;
 static iconv_t ichandle;
 static int resetreader = 1;
-static char *dchostname = NULL;
 static struct addrinfo *hostlist = NULL, *curhost = NULL;
-static int servport;
+struct {
+    char *hostname;
+    int family;
+    int sentcreds;
+} servinfo;
 
 static struct dc_response *makeresp(void)
 {
@@ -273,9 +280,9 @@ void dc_disconnect(void)
     while((resp = dc_getresp()) != NULL)
        dc_freeresp(resp);
     dc_uimisc_disconnected();
-    if(dchostname != NULL)
-       free(dchostname);
-    dchostname = NULL;
+    if(servinfo.hostname != NULL)
+       free(servinfo.hostname);
+    memset(&servinfo, 0, sizeof(servinfo));
 }
 
 void dc_freeresp(struct dc_response *resp)
@@ -371,7 +378,7 @@ int dc_queuecmd(int (*callback)(struct dc_response *), void *data, ...)
     struct qcmd *qcmd;
     int num, freepart;
     va_list al;
-    char *final;
+    char *final, *sarg;
     wchar_t **toks;
     wchar_t *buf;
     wchar_t *part, *tpart;
@@ -383,7 +390,7 @@ int dc_queuecmd(int (*callback)(struct dc_response *), void *data, ...)
     va_start(al, data);
     while((part = va_arg(al, wchar_t *)) != NULL)
     {
-       if(!wcscmp(part, L"%%a"))
+       if(!wcscmp(part, L"%a"))
        {
            for(toks = va_arg(al, wchar_t **); *toks != NULL; toks++)
            {
@@ -403,25 +410,36 @@ int dc_queuecmd(int (*callback)(struct dc_response *), void *data, ...)
        } else {
            if(*part == L'%')
            {
-               /* This demands that all arguments that are passed to the
-                * function are of equal length, that of an int. I know
-                * that GCC does that on IA32 platforms, but I do not know
-                * which other platforms and compilers that it applies
-                * to. If this breaks your platform, please mail me about
-                * it.
-                */
-               part = vswprintf2(tpart = (part + 1), al);
-               for(; *tpart != L'\0'; tpart++)
+               tpart = part + 1;
+               if(!wcscmp(tpart, L"i"))
                {
-                   if(*tpart == L'%')
+                   freepart = 1;
+                   part = swprintf2(L"%i", va_arg(al, int));
+               } else if(!wcscmp(tpart, L"s")) {
+                   freepart = 1;
+                   part = icmbstowcs(sarg = va_arg(al, char *), NULL);
+                   if(part == NULL)
                    {
-                       if(tpart[1] == L'%')
-                           tpart++;
-                       else
-                           va_arg(al, int);
+                       if(buf != NULL)
+                           free(buf);
+                       return(-1);
                    }
+               } else if(!wcscmp(tpart, L"ls")) {
+                   part = va_arg(al, wchar_t *);
+               } else if(!wcscmp(tpart, L"ll")) {
+                   freepart = 1;
+                   part = swprintf2(L"%lli", va_arg(al, long long));
+               } else if(!wcscmp(tpart, L"f")) {
+                   freepart = 1;
+                   part = swprintf2(L"%f", va_arg(al, double));
+               } else if(!wcscmp(tpart, L"x")) {
+                   freepart = 1;
+                   part = swprintf2(L"%x", va_arg(al, int));
+               } else {
+                   if(buf != NULL)
+                       free(buf);
+                   return(-1);
                }
-               freepart = 1;
            } else {
                freepart = 0;
            }
@@ -500,10 +518,7 @@ int dc_handleread(void)
        if(ret)
        {
            int newfd;
-           struct sockaddr_storage addr;
-           struct sockaddr_in *ipv4;
-           struct sockaddr_in6 *ipv6;
-           
+
            for(curhost = curhost->ai_next; curhost != NULL; curhost = curhost->ai_next)
            {
                if((newfd = socket(curhost->ai_family, curhost->ai_socktype, curhost->ai_protocol)) < 0)
@@ -516,20 +531,7 @@ int dc_handleread(void)
                dup2(newfd, fd);
                close(newfd);
                fcntl(fd, F_SETFL, fcntl(fd, F_GETFL) | O_NONBLOCK);
-               memcpy(&addr, curhost->ai_addr, curhost->ai_addrlen);
-               if(addr.ss_family == AF_INET)
-               {
-                   ipv4 = (struct sockaddr_in *)&addr;
-                   ipv4->sin_port = htons(servport);
-               }
-#ifdef HAVE_IPV6
-               if(addr.ss_family == AF_INET6)
-               {
-                   ipv6 = (struct sockaddr_in6 *)&addr;
-                   ipv6->sin6_port = htons(servport);
-               }
-#endif
-               if(connect(fd, (struct sockaddr *)&addr, curhost->ai_addrlen))
+               if(connect(fd, (struct sockaddr *)curhost->ai_addr, curhost->ai_addrlen))
                {
                    if(errno == EINPROGRESS)
                        return(0);
@@ -544,6 +546,9 @@ int dc_handleread(void)
                return(-1);
            }
        }
+       if(curhost->ai_canonname != NULL)
+           servinfo.hostname = sstrdup(curhost->ai_canonname);
+       servinfo.family = curhost->ai_family;
        state = 1;
        resetreader = 1;
        break;
@@ -755,17 +760,48 @@ int dc_handleread(void)
     return(0);
 }
 
+static void mkcreds(struct msghdr *msg)
+{
+    struct ucred *ucred;
+    static char buf[CMSG_SPACE(sizeof(*ucred))];
+    struct cmsghdr *cmsg;
+    
+    msg->msg_control = buf;
+    msg->msg_controllen = sizeof(buf);
+    cmsg = CMSG_FIRSTHDR(msg);
+    cmsg->cmsg_level = SOL_SOCKET;
+    cmsg->cmsg_type = SCM_CREDENTIALS;
+    cmsg->cmsg_len = CMSG_LEN(sizeof(*ucred));
+    ucred = (struct ucred *)CMSG_DATA(cmsg);
+    ucred->pid = getpid();
+    ucred->uid = getuid();
+    ucred->gid = getgid();
+    msg->msg_controllen = cmsg->cmsg_len;
+}
+
 int dc_handlewrite(void)
 {
     int ret;
     int errnobak;
+    struct msghdr msg;
+    struct iovec bufvec;
     
     switch(state)
     {
     case 1:
        if(queue->buflen > 0)
        {
-           ret = send(fd, queue->buf, queue->buflen, MSG_NOSIGNAL | MSG_DONTWAIT);
+           memset(&msg, 0, sizeof(msg));
+           msg.msg_iov = &bufvec;
+           msg.msg_iovlen = 1;
+           bufvec.iov_base = queue->buf;
+           bufvec.iov_len = queue->buflen;
+           if((servinfo.family == PF_UNIX) && !servinfo.sentcreds)
+           {
+               mkcreds(&msg);
+               servinfo.sentcreds = 1;
+           }
+           ret = sendmsg(fd, &msg, MSG_NOSIGNAL | MSG_DONTWAIT);
            if(ret < 0)
            {
                if((errno == EAGAIN) || (errno == EINTR))
@@ -784,6 +820,10 @@ int dc_handlewrite(void)
 }
 
 #ifdef HAVE_RESOLVER
+/*
+ * It kind of sucks that libresolv doesn't have any DNS parsing
+ * routines. We'll have to do it manually.
+ */
 static char *readname(unsigned char *msg, unsigned char *eom, unsigned char **p)
 {
     char *name, *tname;
@@ -858,15 +898,10 @@ static int getsrvrr(char *name, char **host, int *port)
            return(-1);
     }
     /* res_querydomain doesn't work for some reason */
-    name2 = smalloc(strlen("_dolcon._tcp.") + strlen(name) + 2);
-    strcpy(name2, "_dolcon._tcp.");
-    strcat(name2, name);
-    len = strlen(name2);
-    if(name2[len - 1] != '.')
-    {
-       name2[len] = '.';
-       name2[len + 1] = 0;
-    }
+    if(name[strlen(name) - 1] == '.')
+       name2 = sprintf2("%s.%s", DOLCON_SRV_NAME, name);
+    else
+       name2 = sprintf2("%s.%s.", DOLCON_SRV_NAME, name);
     ret = res_query(name2, C_IN, T_SRV, buf, sizeof(buf));
     if(ret < 0)
     {
@@ -874,12 +909,20 @@ static int getsrvrr(char *name, char **host, int *port)
        return(-1);
     }
     eom = buf + ret;
+    /*
+     * Assume transaction ID is correct.
+     *
+     * Flags check: FA0F masks in request/response flag, opcode,
+     * truncated flag and status code, and ignores authoritativeness,
+     * recursion flags and DNSSEC and reserved bits.
+     */
     flags = (buf[2] << 8) + buf[3];
     if((flags & 0xfa0f) != 0x8000)
     {
        free(name2);
        return(-1);
     }
+    /* Skip the query entries */
     num = (buf[4] << 8) + buf[5];
     p = buf + 12;
     for(i = 0; i < num; i++)
@@ -889,8 +932,9 @@ static int getsrvrr(char *name, char **host, int *port)
            free(name2);
            return(-1);
        }
-       p += 4;
+       p += 4; /* Type and class */
     }
+    /* Parse answers */
     num = (buf[6] << 8) + buf[7];
     for(i = 0; i < num; i++)
     {
@@ -903,7 +947,7 @@ static int getsrvrr(char *name, char **host, int *port)
        type += *(p++);
        class = *(p++) << 8;
        class += *(p++);
-       p += 4;
+       p += 4; /* TTL */
        len = *(p++) << 8;
        len += *(p++);
        if((class == C_IN) && (type == T_SRV) && !strcmp(rrname, name2))
@@ -942,75 +986,165 @@ static int getsrvrr(char *name, char **host, int *port)
 }
 #endif
 
-int dc_connect(char *host, int port)
+static struct addrinfo *gaicat(struct addrinfo *l1, struct addrinfo *l2)
 {
-    struct addrinfo hint;
-    struct sockaddr_storage addr;
-    struct sockaddr_in *ipv4;
-#ifdef HAVE_IPV6
-    struct sockaddr_in6 *ipv6;
-#endif
-    struct qcmd *qcmd;
-    char *newhost;
-    int getsrv, freehost;
-    int errnobak;
+    struct addrinfo *p;
+    
+    if(l1 == NULL)
+       return(l2);
+    for(p = l1; p->ai_next != NULL; p = p->ai_next);
+    p->ai_next = l2;
+    return(l1);
+}
+
+/* This isn't actually correct, in any sense of the word. It only
+ * works on systems whose getaddrinfo implementation saves the
+ * sockaddr in the same malloc block as the struct addrinfo. Those
+ * systems include at least FreeBSD and glibc-based systems, though,
+ * so it should not be any immediate threat, and it allows me to not
+ * implement a getaddrinfo wrapper. It can always be changed, should
+ * the need arise. */
+static struct addrinfo *unixgai(int type, char *path)
+{
+    void *buf;
+    struct addrinfo *ai;
+    struct sockaddr_un *un;
+    
+    buf = smalloc(sizeof(*ai) + sizeof(*un));
+    memset(buf, 0, sizeof(*ai) + sizeof(*un));
+    ai = (struct addrinfo *)buf;
+    un = (struct sockaddr_un *)(buf + sizeof(*ai));
+    ai->ai_flags = 0;
+    ai->ai_family = AF_UNIX;
+    ai->ai_socktype = type;
+    ai->ai_protocol = 0;
+    ai->ai_addrlen = sizeof(*un);
+    ai->ai_addr = (struct sockaddr *)un;
+    ai->ai_canonname = NULL;
+    ai->ai_next = NULL;
+    un->sun_family = PF_UNIX;
+    strncpy(un->sun_path, path, sizeof(un->sun_path) - 1);
+    return(ai);
+}
+
+static struct addrinfo *resolvtcp(char *name, int port)
+{
+    struct addrinfo hint, *ret;
+    char tmp[32];
     
-    if(fd >= 0)
-       dc_disconnect();
-    state = -1;
-    freehost = 0;
-    if(port < 0)
-    {
-       port = 1500;
-       getsrv = 1;
-    } else {
-       getsrv = 0;
-    }
     memset(&hint, 0, sizeof(hint));
     hint.ai_socktype = SOCK_STREAM;
-    if(getsrv)
+    hint.ai_flags = AI_NUMERICSERV | AI_CANONNAME;
+    snprintf(tmp, sizeof(tmp), "%i", port);
+    if(!getaddrinfo(name, tmp, &hint, &ret))
+       return(ret);
+    return(NULL);
+}
+
+static struct addrinfo *resolvsrv(char *name)
+{
+    struct addrinfo *ret;
+    char *realname;
+    int port;
+    
+    if(getsrvrr(name, &realname, &port))
+       return(NULL);
+    ret = resolvtcp(realname, port);
+    free(realname);
+    return(ret);
+}
+
+static struct addrinfo *resolvhost(char *host)
+{
+    char *p, *hp;
+    struct addrinfo *ret;
+    int port;
+    
+    if(strchr(host, '/'))
+       return(unixgai(SOCK_STREAM, host));
+    if((strchr(host, ':') == NULL) && ((ret = resolvsrv(host)) != NULL))
+       return(ret);
+    ret = NULL;
+    if((*host == '[') && ((p = strchr(host, ']')) != NULL))
     {
-       if(!getsrvrr(host, &newhost, &port))
-       {
-           host = newhost;
-           freehost = 1;
+       hp = memcpy(smalloc(p - host), host + 1, (p - host) - 1);
+       hp[(p - host) - 1] = 0;
+       if(strchr(hp, ':') != NULL) {
+           port = 0;
+           if(*(++p) == ':')
+               port = atoi(p + 1);
+           if(port == 0)
+               port = 1500;
+           ret = resolvtcp(hp, port);
        }
+       free(hp);
+    }
+    if(ret != NULL)
+       return(ret);
+    hp = sstrdup(host);
+    port = 0;
+    if((p = strrchr(hp, ':')) != NULL) {
+       *(p++) = 0;
+       port = atoi(p);
     }
-    servport = port;
+    if(port == 0)
+       port = 1500;
+    ret = resolvtcp(hp, port);
+    free(hp);
+    if(ret != NULL)
+       return(ret);
+    return(NULL);
+}
+
+static struct addrinfo *defaulthost(void)
+{
+    struct addrinfo *ret;
+    struct passwd *pwd;
+    char *tmp;
+    char dn[1024];
+    
+    if(((tmp = getenv("DCSERVER")) != NULL) && *tmp)
+       return(resolvhost(tmp));
+    ret = NULL;
+    if((getuid() != 0) && ((pwd = getpwuid(getuid())) != NULL))
+    {
+       tmp = sprintf2("/tmp/doldacond-%s", pwd->pw_name);
+       ret = gaicat(ret, unixgai(SOCK_STREAM, tmp));
+       free(tmp);
+    }
+    ret = gaicat(ret, unixgai(SOCK_STREAM, "/var/run/doldacond.sock"));
+    ret = gaicat(ret, resolvtcp("localhost", 1500));
+    if(!getdomainname(dn, sizeof(dn)) && *dn && strcmp(dn, "(none)"))
+       ret = gaicat(ret, resolvsrv(dn));
+    return(ret);
+}
+
+int dc_connect(char *host)
+{
+    struct qcmd *qcmd;
+    int errnobak;
+    
+    if(fd >= 0)
+       dc_disconnect();
+    state = -1;
     if(hostlist != NULL)
        freeaddrinfo(hostlist);
-    if(getaddrinfo(host, NULL, &hint, &hostlist))
-    {
-       errno = ENONET;
-       if(freehost)
-           free(host);
+    if(!host || !*host)
+       hostlist = defaulthost();
+    else
+       hostlist = resolvhost(host);
+    if(hostlist == NULL)
        return(-1);
-    }
     for(curhost = hostlist; curhost != NULL; curhost = curhost->ai_next)
     {
        if((fd = socket(curhost->ai_family, curhost->ai_socktype, curhost->ai_protocol)) < 0)
        {
            errnobak = errno;
-           if(freehost)
-               free(host);
            errno = errnobak;
            return(-1);
        }
        fcntl(fd, F_SETFL, fcntl(fd, F_GETFL) | O_NONBLOCK);
-       memcpy(&addr, curhost->ai_addr, curhost->ai_addrlen);
-       if(addr.ss_family == AF_INET)
-       {
-           ipv4 = (struct sockaddr_in *)&addr;
-           ipv4->sin_port = htons(port);
-       }
-#ifdef HAVE_IPV6
-       if(addr.ss_family == AF_INET6)
-       {
-           ipv6 = (struct sockaddr_in6 *)&addr;
-           ipv6->sin6_port = htons(port);
-       }
-#endif
-       if(connect(fd, (struct sockaddr *)&addr, curhost->ai_addrlen))
+       if(connect(fd, (struct sockaddr *)curhost->ai_addr, curhost->ai_addrlen))
        {
            if(errno == EINPROGRESS)
            {
@@ -1020,17 +1154,15 @@ int dc_connect(char *host, int port)
            close(fd);
            fd = -1;
        } else {
+           if(curhost->ai_canonname != NULL)
+               servinfo.hostname = sstrdup(curhost->ai_canonname);
+           servinfo.family = curhost->ai_family;
            state = 1;
            break;
        }
     }
     qcmd = makeqcmd(NULL);
     resetreader = 1;
-    if(dchostname != NULL)
-       free(dchostname);
-    dchostname = sstrdup(host);
-    if(freehost)
-       free(host);
     return(fd);
 }
 
@@ -1104,7 +1236,25 @@ void dc_freeires(struct dc_intresp *ires)
     free(ires);
 }
 
+int dc_checkprotocol(struct dc_response *resp, int revision)
+{
+    struct dc_intresp *ires;
+    int low, high;
+    
+    if(resp->code != 201)
+       return(-1);
+    resp->curline = 0;
+    if((ires = dc_interpret(resp)) == NULL)
+       return(-1);
+    low = ires->argv[0].val.num;
+    high = ires->argv[0].val.num;
+    dc_freeires(ires);
+    if((revision < low) || (revision > high))
+       return(-1);
+    return(0);
+}
+
 const char *dc_gethostname(void)
 {
-    return(dchostname);
+    return(servinfo.hostname);
 }