3d2e6d9ecb24908b857ce59b9031ba794b591cbe
[ashd.git] / src / ratequeue.c
1 /*
2     ashd - A Sane HTTP Daemon
3     Copyright (C) 2008  Fredrik Tolf <fredrik@dolda2000.com>
4
5     This program is free software: you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation, either version 3 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License
16     along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 #include <stdlib.h>
20 #include <stdio.h>
21 #include <unistd.h>
22 #include <errno.h>
23 #include <string.h>
24 #include <time.h>
25 #include <signal.h>
26 #include <assert.h>
27 #include <sys/poll.h>
28 #include <sys/socket.h>
29 #include <netinet/in.h>
30 #include <arpa/inet.h>
31
32 #ifdef HAVE_CONFIG_H
33 #include <config.h>
34 #endif
35 #include <utils.h>
36 #include <log.h>
37 #include <req.h>
38 #include <resp.h>
39 #include <proc.h>
40 #include <cf.h>
41
42 #define SBUCKETS 7
43
44 struct source {
45     int type;
46     char data[16];
47     unsigned int len, hash;
48 };
49
50 struct waiting {
51     struct hthead *req;
52     int fd;
53 };
54
55 struct bucket {
56     struct source id;
57     double level, last, etime, wtime;
58     typedbuf(struct waiting) brim;
59     int thpos, blocked;
60 };
61
62 struct btime {
63     struct bucket *bk;
64     double tm;
65 };
66
67 struct config {
68     double size, rate, retain, warnrate;
69     int brimsize;
70 };
71
72 static struct bucket *sbuckets[1 << SBUCKETS];
73 static struct bucket **buckets = sbuckets;
74 static int hashlen = SBUCKETS, nbuckets = 0;
75 static typedbuf(struct btime) timeheap;
76 static int child, reload;
77 static double now;
78 static const struct config defcfg = {
79     .size = 100, .rate = 10, .warnrate = 60,
80     .retain = 10, .brimsize = 10,
81 };
82 static struct config cf;
83
84 static double rtime(void)
85 {
86     static int init = 0;
87     static struct timespec or;
88     struct timespec ts;
89     
90     clock_gettime(CLOCK_MONOTONIC, &ts);
91     if(!init) {
92         or = ts;
93         init = 1;
94     }
95     return((ts.tv_sec - or.tv_sec) + ((ts.tv_nsec - or.tv_nsec) / 1000000000.0));
96 }
97
98 static struct source reqsource(struct hthead *req)
99 {
100     int i;
101     char *sa;
102     struct in_addr a4;
103     struct in6_addr a6;
104     struct source ret;
105     
106     ret = (struct source){};
107     if((sa = getheader(req, "X-Ash-Address")) != NULL) {
108         if(inet_pton(AF_INET, sa, &a4) == 1) {
109             ret.type = AF_INET;
110             memcpy(ret.data, &a4, ret.len = sizeof(a4));
111         } else if(inet_pton(AF_INET6, sa, &a6) == 1) {
112             ret.type = AF_INET6;
113             memcpy(ret.data, &a6, ret.len = sizeof(a6));
114         }
115     }
116     for(i = 0, ret.hash = ret.type; i < ret.len; i++)
117         ret.hash = (ret.hash * 31) + ret.data[i];
118     return(ret);
119 }
120
121 static int srccmp(const struct source *a, const struct source *b)
122 {
123     int c;
124     
125     if((c = a->len - b->len) != 0)
126         return(c);
127     if((c = a->type - b->type) != 0)
128         return(c);
129     return(memcmp(a->data, b->data, a->len));
130 }
131
132 static const char *formatsrc(const struct source *src)
133 {
134     static char buf[128];
135     struct in_addr a4;
136     struct in6_addr a6;
137     
138     switch(src->type) {
139     case AF_INET:
140         memcpy(&a4, src->data, sizeof(a4));
141         if(!inet_ntop(AF_INET, &a4, buf, sizeof(buf)))
142             return("<invalid ipv4>");
143         return(buf);
144     case AF_INET6:
145         memcpy(&a6, src->data, sizeof(a6));
146         if(!inet_ntop(AF_INET6, &a6, buf, sizeof(buf)))
147             return("<invalid ipv6>");
148         return(buf);
149     default:
150         return("<invalid source record>");
151     }
152 }
153
154 static void rehash(int nlen)
155 {
156     unsigned int i, o, n, m, pl, nl;
157     struct bucket **new, **old;
158     
159     old = buckets;
160     if(nlen <= SBUCKETS) {
161         nlen = SBUCKETS;
162         new = sbuckets;
163     } else {
164         new = smalloc(sizeof(*new) * (1 << nlen));
165     }
166     if(nlen == hashlen)
167         return;
168     memset(new, 0, sizeof(*new) * (1 << nlen));
169     assert(old != new);
170     pl = 1 << hashlen; nl = 1 << nlen; m = nl - 1;
171     for(i = 0; i < pl; i++) {
172         if(!old[i])
173             continue;
174         for(o = old[i]->id.hash & m, n = 0; n < nl; o = (o + 1) & m, n++) {
175             if(!new[o]) {
176                 new[o] = old[i];
177                 break;
178             }
179         }
180     }
181     if(old != sbuckets)
182         free(old);
183     buckets = new;
184     hashlen = nlen;
185 }
186
187 static struct bucket *hashget(const struct source *src)
188 {
189     unsigned int i, n, N, m;
190     struct bucket *bk;
191     
192     m = (N = (1 << hashlen)) - 1;
193     for(i = src->hash & m, n = 0; n < N; i = (i + 1) & m, n++) {
194         bk = buckets[i];
195         if(bk && !srccmp(&bk->id, src))
196             return(bk);
197     }
198     for(i = src->hash & m; buckets[i]; i = (i + 1) & m);
199     buckets[i] = bk = szmalloc(sizeof(*bk));
200     memcpy(&bk->id, src, sizeof(*src));
201     bk->last = bk->etime = now;
202     bk->thpos = -1;
203     bk->blocked = -1;
204     if(++nbuckets > (1 << (hashlen - 1)))
205         rehash(hashlen + 1);
206     return(bk);
207 }
208
209 static void hashdel(struct bucket *bk)
210 {
211     unsigned int i, o, p, n, N, m;
212     struct bucket *sb;
213     
214     m = (N = (1 << hashlen)) - 1;
215     for(i = bk->id.hash & m, n = 0; n < N; i = (i + 1) & m, n++) {
216         assert((sb = buckets[i]) != NULL);
217         if(!srccmp(&sb->id, &bk->id))
218             break;
219     }
220     assert(sb == bk);
221     buckets[i] = NULL;
222     for(o = (i + 1) & m; buckets[o] != NULL; o = (o + 1) & m) {
223         sb = buckets[o];
224         p = (sb->id.hash - i) & m;
225         if((p == 0) || (p > ((o - i) & m))) {
226             buckets[i] = sb;
227             buckets[o] = NULL;
228             i = o;
229         }
230     }
231     if(--nbuckets <= (1 << (hashlen - 3)))
232         rehash(hashlen - 1);
233 }
234
235 static void thraise(struct btime bt, int n)
236 {
237     int p;
238     
239     while(n > 0) {
240         p = (n - 1) >> 1;
241         if(timeheap.b[p].tm <= bt.tm)
242             break;
243         (timeheap.b[n] = timeheap.b[p]).bk->thpos = n;
244         n = p;
245     }
246     (timeheap.b[n] = bt).bk->thpos = n;
247 }
248
249 static void thlower(struct btime bt, int n)
250 {
251     int c1, c2, c;
252     
253     while(1) {
254         c2 = (c1 = (n << 1) + 1) + 1;
255         if(c1 >= timeheap.d)
256             break;
257         c = ((c2 < timeheap.d) && (timeheap.b[c2].tm < timeheap.b[c1].tm)) ? c2 : c1;
258         if(timeheap.b[c].tm > bt.tm)
259             break;
260         (timeheap.b[n] = timeheap.b[c]).bk->thpos = n;
261         n = c;
262     }
263     (timeheap.b[n] = bt).bk->thpos = n;
264 }
265
266 static void thadjust(struct btime bt, int n)
267 {
268     if((n > 0) && (timeheap.b[(n - 1) >> 1].tm > bt.tm))
269         thraise(bt, n);
270     else
271         thlower(bt, n);
272 }
273
274 static void freebucket(struct bucket *bk)
275 {
276     int i, n;
277     struct btime r;
278     
279     hashdel(bk);
280     if((n = bk->thpos) >= 0) {
281         r = timeheap.b[--timeheap.d];
282         if(n < timeheap.d)
283             thadjust(r, n);
284     }
285     for(i = 0; i < bk->brim.d; i++) {
286         freehthead(bk->brim.b[i].req);
287         close(bk->brim.b[i].fd);
288     }
289     buffree(bk->brim);
290     free(bk);
291 }
292
293 static void updbtime(struct bucket *bk)
294 {
295     double tm, tm2;
296     
297     tm = (bk->level == 0) ? (bk->etime + cf.retain) : (bk->last + (bk->level / cf.rate) + cf.retain);
298     if((bk->blocked > 0) && ((tm2 = bk->wtime + cf.warnrate) > tm))
299         tm = tm2;
300     
301     if((bk->brim.d > 0) && ((tm2 = bk->last + ((bk->level - cf.size) / cf.rate)) < tm))
302         tm = tm2;
303     if((bk->blocked > 0) && ((tm2 = bk->wtime + cf.warnrate) < tm))
304         tm = tm2;
305     
306     if(bk->thpos < 0) {
307         sizebuf(timeheap, ++timeheap.d);
308         thraise((struct btime){bk, tm}, timeheap.d - 1);
309     } else {
310         thadjust((struct btime){bk, tm}, bk->thpos);
311     }
312 }
313
314 static void tickbucket(struct bucket *bk)
315 {
316     double delta, ll;
317     
318     delta = now - bk->last;
319     bk->last = now;
320     ll = bk->level;
321     if((bk->level -= delta * cf.rate) < 0) {
322         if(ll > 0)
323             bk->etime = now + (bk->level / cf.rate);
324         bk->level = 0;
325     }
326     while((bk->brim.d > 0) && (bk->level < cf.size)) {
327         if(sendreq(child, bk->brim.b[0].req, bk->brim.b[0].fd)) {
328             flog(LOG_ERR, "ratequeue: could not pass request to child: %s", strerror(errno));
329             exit(1);
330         }
331         freehthead(bk->brim.b[0].req);
332         close(bk->brim.b[0].fd);
333         bufdel(bk->brim, 0);
334         bk->level += 1;
335     }
336     if((bk->blocked > 0) && (now - bk->wtime >= cf.warnrate)) {
337         flog(LOG_NOTICE, "ratequeue: blocked %i requests from %s", bk->blocked, formatsrc(&bk->id));
338         bk->blocked = 0;
339         bk->wtime = now;
340     }
341 }
342
343 static void checkbtime(struct bucket *bk)
344 {
345     tickbucket(bk);
346     if((bk->level == 0) && (now >= bk->etime + cf.retain) && (bk->blocked <= 0)) {
347         freebucket(bk);
348         return;
349     }
350     updbtime(bk);
351 }
352
353 static void serve(struct hthead *req, int fd)
354 {
355     struct source src;
356     struct bucket *bk;
357     
358     now = rtime();
359     src = reqsource(req);
360     bk = hashget(&src);
361     tickbucket(bk);
362     if(bk->level < cf.size) {
363         bk->level += 1;
364         if(sendreq(child, req, fd)) {
365             flog(LOG_ERR, "ratequeue: could not pass request to child: %s", strerror(errno));
366             exit(1);
367         }
368         freehthead(req);
369         close(fd);
370     } else if(bk->brim.d < cf.brimsize) {
371         bufadd(bk->brim, ((struct waiting){.req = req, .fd = fd}));
372     } else {
373         if(bk->blocked < 0) {
374             flog(LOG_NOTICE, "ratequeue: blocking requests from %s", formatsrc(&bk->id));
375             bk->blocked = 0;
376             bk->wtime = now;
377         }
378         simpleerror(fd, 429, "Too many requests", "Your client is being throttled for issuing too frequent requests.");
379         freehthead(req);
380         close(fd);
381         bk->blocked++;
382     }
383     updbtime(bk);
384 }
385
386 static int parseint(const char *str, int *dst)
387 {
388     long buf;
389     char *p;
390     
391     buf = strtol(str, &p, 0);
392     if((p == str) || *p)
393         return(-1);
394     *dst = buf;
395     return(0);
396 }
397
398 static int parsefloat(const char *str, double *dst)
399 {
400     double buf;
401     char *p;
402     
403     buf = strtod(str, &p);
404     if((p == str) || *p)
405         return(-1);
406     *dst = buf;
407     return(0);
408 }
409
410 static int readconf(char *path, struct config *buf)
411 {
412     FILE *fp;
413     struct cfstate *s;
414     int rv;
415     
416     if((fp = fopen(path, "r")) == NULL) {
417         flog(LOG_ERR, "ratequeue: %s: %s", path, strerror(errno));
418         return(-1);
419     }
420     *buf = defcfg;
421     s = mkcfparser(fp, path);
422     rv = -1;
423     while(1) {
424         getcfline(s);
425         if(!strcmp(s->argv[0], "eof")) {
426             break;
427         } else if(!strcmp(s->argv[0], "size")) {
428             if((s->argc < 2) || parsefloat(s->argv[1], &buf->size)) {
429                 flog(LOG_ERR, "%s:%i: missing or invalid `size' argument");
430                 goto err;
431             }
432         } else if(!strcmp(s->argv[0], "rate")) {
433             if((s->argc < 2) || parsefloat(s->argv[1], &buf->rate)) {
434                 flog(LOG_ERR, "%s:%i: missing or invalid `rate' argument");
435                 goto err;
436             }
437         } else if(!strcmp(s->argv[0], "brim")) {
438             if((s->argc < 2) || parseint(s->argv[1], &buf->brimsize)) {
439                 flog(LOG_ERR, "%s:%i: missing or invalid `brim' argument");
440                 goto err;
441             }
442         } else {
443             flog(LOG_WARNING, "%s:%i: unknown directive `%s'", s->file, s->lno, s->argv[0]);
444         }
445     }
446     rv = 0;
447 err:
448     freecfparser(s);
449     fclose(fp);
450     return(rv);
451 }
452
453 static void huphandler(int sig)
454 {
455     reload = 1;
456 }
457
458 static void usage(FILE *out)
459 {
460     fprintf(out, "usage: ratequeue [-h] [-s BUCKET-SIZE] [-r RATE] [-b BRIM-SIZE] PROGRAM [ARGS...]\n");
461 }
462
463 int main(int argc, char **argv)
464 {
465     int c, rv;
466     int fd;
467     struct hthead *req;
468     struct pollfd pfd;
469     double timeout;
470     char *cfname;
471     struct config cfbuf;
472     
473     cf = defcfg;
474     cfname = NULL;
475     while((c = getopt(argc, argv, "+hc:s:r:b:")) >= 0) {
476         switch(c) {
477         case 'h':
478             usage(stdout);
479             return(0);
480         case 'c':
481             cfname = optarg;
482             break;
483         case 's':
484             parsefloat(optarg, &cf.size);
485             break;
486         case 'r':
487             parsefloat(optarg, &cf.rate);
488             break;
489         case 'b':
490             parseint(optarg, &cf.brimsize);
491             break;
492         }
493     }
494     if(argc - optind < 1) {
495         usage(stderr);
496         return(1);
497     }
498     if(cfname) {
499         if(readconf(cfname, &cfbuf))
500             return(1);
501         cf = cfbuf;
502     }
503     if((child = stdmkchild(argv + optind, NULL, NULL)) < 0) {
504         flog(LOG_ERR, "ratequeue: could not fork child: %s", strerror(errno));
505         return(1);
506     }
507     sigaction(SIGHUP, &(struct sigaction){.sa_handler = huphandler}, NULL);
508     while(1) {
509         if(reload) {
510             if(cfname) {
511                 if(!readconf(cfname, &cfbuf))
512                     cf = cfbuf;
513             }
514             reload = 0;
515         }
516         now = rtime();
517         pfd = (struct pollfd){.fd = 0, .events = POLLIN};
518         timeout = (timeheap.d > 0) ? timeheap.b[0].tm : -1;
519         if((rv = poll(&pfd, 1, (timeout < 0) ? -1 : (int)((timeout + 0.1 - now) * 1000))) < 0) {
520             if(errno != EINTR) {
521                 flog(LOG_ERR, "ratequeue: error in poll: %s", strerror(errno));
522                 exit(1);
523             }
524         }
525         if(pfd.revents) {
526             if((fd = recvreq(0, &req)) < 0) {
527                 if(errno == EINTR)
528                     continue;
529                 if(errno != 0)
530                     flog(LOG_ERR, "recvreq: %s", strerror(errno));
531                 break;
532             }
533             serve(req, fd);
534         }
535         while((timeheap.d > 0) && ((now = rtime()) >= timeheap.b[0].tm))
536             checkbtime(timeheap.b[0].bk);
537     }
538     return(0);
539 }