acmecert: Reorganized a bit.
[utils.git] / acmecert
index 9b30c44..4c56ad8 100755 (executable)
--- a/acmecert
+++ b/acmecert
@@ -1,17 +1,16 @@
 #!/usr/bin/python3
 
-import sys, os, getopt, binascii, json, pprint, signal, time
+#### ACME client (only http-01 challenges supported thus far)
+
+import sys, os, getopt, binascii, json, pprint, signal, time, threading
 import urllib.request
 import Crypto.PublicKey.RSA, Crypto.Random, Crypto.Hash.SHA256, Crypto.Signature.PKCS1_v1_5
 
-service = "https://acme-v02.api.letsencrypt.org/directory"
-_directory = None
-def directory():
-    global _directory
-    if _directory is None:
-        with req(service) as resp:
-            _directory = json.loads(resp.read().decode("utf-8"))
-    return _directory
+### General utilities
+
+class msgerror(Exception):
+    def report(self, out):
+        out.write("acmecert: undefined error\n")
 
 def base64url(dat):
     return binascii.b2a_base64(dat).decode("us-ascii").translate({43: 45, 47: 95, 61: None}).strip()
@@ -21,34 +20,29 @@ def ebignum(num):
     if len(h) % 2 == 1: h = "0" + h
     return base64url(binascii.a2b_hex(h))
 
-def getnonce():
-    with urllib.request.urlopen(directory()["newNonce"]) as resp:
-        resp.read()
-        return resp.headers["Replay-Nonce"]
+class maybeopen(object):
+    def __init__(self, name, mode):
+        if name == "-":
+            self.opened = False
+            if mode == "r":
+                self.fp = sys.stdin
+            elif mode == "w":
+                self.fp = sys.stdout
+            else:
+                raise ValueError(mode)
+        else:
+            self.opened = True
+            self.fp = open(name, mode)
 
-def req(url, data=None, ctype=None, headers={}, method=None, **kws):
-    if data is not None and not isinstance(data, bytes):
-        data = json.dumps(data).encode("utf-8")
-        ctype = "application/jose+json"
-    req = urllib.request.Request(url, data=data, method=method)
-    for hnam, hval in headers.items():
-        req.add_header(hnam, hval)
-    if ctype is not None:
-        req.add_header("Content-Type", ctype)
-    return urllib.request.urlopen(req)
+    def __enter__(self):
+        return self.fp
 
-def jreq(url, data, auth):
-    authdata = {"alg": "RS256", "url": url, "nonce": getnonce()}
-    authdata.update(auth.authdata())
-    authdata = base64url(json.dumps(authdata).encode("us-ascii"))
-    if data is None:
-        data = ""
-    else:
-        data = base64url(json.dumps(data).encode("us-ascii"))
-    seal = base64url(auth.sign(("%s.%s" % (authdata, data)).encode("us-ascii")))
-    enc = {"protected": authdata, "payload": data, "signature": seal}
-    with req(url, data=enc) as resp:
-        return json.loads(resp.read().decode("utf-8")), resp.headers
+    def __exit__(self, *excinfo):
+        if self.opened:
+            self.fp.close()
+        return False
+
+### Crypto utilities
 
 class certificate(object):
     @property
@@ -123,6 +117,86 @@ class signreq(object):
         self.data = fp.read()
         return self
 
+### Somewhat general request utilities
+
+def getnonce():
+    with urllib.request.urlopen(directory()["newNonce"]) as resp:
+        resp.read()
+        return resp.headers["Replay-Nonce"]
+
+def req(url, data=None, ctype=None, headers={}, method=None, **kws):
+    if data is not None and not isinstance(data, bytes):
+        data = json.dumps(data).encode("utf-8")
+        ctype = "application/jose+json"
+    req = urllib.request.Request(url, data=data, method=method)
+    for hnam, hval in headers.items():
+        req.add_header(hnam, hval)
+    if ctype is not None:
+        req.add_header("Content-Type", ctype)
+    return urllib.request.urlopen(req)
+
+class problem(msgerror):
+    def __init__(self, code, data, *args, url=None, **kw):
+        super().__init__(*args, **kw)
+        self.code = code
+        self.data = data
+        self.url = url
+        if not isinstance(data, dict):
+            raise ValueError("unexpected problem object type: %r" % (data,))
+
+    @property
+    def type(self):
+        return self.data.get("type", "about:blank")
+    @property
+    def title(self):
+        return self.data.get("title")
+    @property
+    def detail(self):
+        return self.data.get("detail")
+
+    def report(self, out):
+        extra = None
+        if self.title is None:
+            msg = self.detail
+            if "\n" in msg:
+                extra, msg = msg, None
+        else:
+            msg = self.title
+            extra = self.detail
+        if msg is None:
+            msg = self.data.get("type")
+        if msg is not None:
+            out.write("acemcert: %s: %s\n" % (
+                ("remote service error" if self.url is None else self.url),
+                ("unspecified error" if msg is None else msg)))
+        if extra is not None:
+            out.write("%s\n" % (extra,))
+
+    @classmethod
+    def read(cls, err, **kw):
+        self = cls(err.code, json.load(err), **kw)
+        return self
+
+def jreq(url, data, auth):
+    authdata = {"alg": "RS256", "url": url, "nonce": getnonce()}
+    authdata.update(auth.authdata())
+    authdata = base64url(json.dumps(authdata).encode("us-ascii"))
+    if data is None:
+        data = ""
+    else:
+        data = base64url(json.dumps(data).encode("us-ascii"))
+    seal = base64url(auth.sign(("%s.%s" % (authdata, data)).encode("us-ascii")))
+    enc = {"protected": authdata, "payload": data, "signature": seal}
+    try:
+        with req(url, data=enc) as resp:
+            return json.load(resp), resp.headers
+    except urllib.error.HTTPError as exc:
+        if exc.headers["Content-Type"] == "application/problem+json":
+            raise problem.read(exc, url=url)
+        raise
+
+## Authentication
+
 class jwkauth(object):
     def __init__(self, key):
         self.key = key
@@ -170,27 +244,19 @@ class account(object):
         key = Crypto.PublicKey.RSA.importKey(fp.read())
         return cls(uri, key)
 
-class htconfig(object):
-    def __init__(self):
-        self.roots = {}
+### ACME protocol
 
-    @classmethod
-    def read(cls, fp):
-        self = cls()
-        for ln in fp:
-            words = ln.split()
-            if len(words) < 1 or ln[0] == '#':
-                continue
-            if words[0] == "root":
-                self.roots[words[1]] = words[2]
-            else:
-                sys.stderr.write("acmecert: warning: unknown htconfig directive: %s\n" % (words[0]))
-        return self
+service = "https://acme-v02.api.letsencrypt.org/directory"
+_directory = None
+def directory():
+    global _directory
+    if _directory is None:
+        with req(service) as resp:
+            _directory = json.load(resp)
+    return _directory
 
 def register(keysize=4096):
     key = Crypto.PublicKey.RSA.generate(keysize, Crypto.Random.new().read)
-    # jwk = {"kty": "RSA", "e": ebignum(key.e), "n": ebignum(key.n)}
-    # cjwk = json.dumps(jwk, separators=(',', ':'), sort_keys=True)
     data, headers = jreq(directory()["newAccount"], {"termsOfServiceAgreed": True}, jwkauth(key))
     return account(headers["Location"], key)
     
@@ -206,6 +272,47 @@ def httptoken(acct, ch):
     khash = base64url(dig.digest())
     return ch["token"], ("%s.%s" % (ch["token"], khash))
 
+def finalize(acct, csr, orderid):
+    order, headers = jreq(orderid, None, acct)
+    if order["status"] == "valid":
+        pass
+    elif order["status"] == "ready":
+        jreq(order["finalize"], {"csr": base64url(csr.der())}, acct)
+        for n in range(30):
+            resp, headers = jreq(orderid, None, acct)
+            if resp["status"] == "processing":
+                time.sleep(2)
+            elif resp["status"] == "valid":
+                order = resp
+                break
+            else:
+                raise Exception("unexpected order status when finalizing: %s" % resp["status"])
+        else:
+            raise Exception("order finalization timed out")
+    else:
+        raise Exception("unexpected order state when finalizing: %s" % (order["status"],))
+    with req(order["certificate"]) as resp:
+        return resp.read().decode("us-ascii")
+
+## http-01 challenge
+
+class htconfig(object):
+    def __init__(self):
+        self.roots = {}
+
+    @classmethod
+    def read(cls, fp):
+        self = cls()
+        for ln in fp:
+            words = ln.split()
+            if len(words) < 1 or ln[0] == '#':
+                continue
+            if words[0] == "root":
+                self.roots[words[1]] = words[2]
+            else:
+                sys.stderr.write("acmecert: warning: unknown htconfig directive: %s\n" % (words[0]))
+        return self
+
 def authorder(acct, htconf, orderid):
     order, headers = jreq(orderid, None, acct)
     valid = False
@@ -248,6 +355,13 @@ def authorder(acct, htconf, orderid):
                     resp, headers = jreq(ch["url"], {}, acct)
                     if resp["status"] == "processing":
                         time.sleep(2)
+                    elif resp["status"] == "pending":
+                        # I don't think this should happen, but it
+                        # does. LE bug? Anyway, just retry.
+                        if n < 5:
+                            time.sleep(2)
+                        else:
+                            break
                     elif resp["status"] == "valid":
                         break
                     else:
@@ -257,95 +371,164 @@ def authorder(acct, htconf, orderid):
             finally:
                 os.unlink(tokpath)
 
-def finalize(acct, csr, orderid):
-    order, headers = jreq(orderid, None, acct)
-    if order["status"] == "valid":
-        pass
-    elif order["status"] == "ready":
-        jreq(order["finalize"], {"csr": base64url(csr.der())}, acct)
-        for n in range(30):
-            resp, headers = jreq(orderid, None, acct)
-            if resp["status"] == "processing":
-                time.sleep(2)
-            elif resp["status"] == "valid":
-                order = resp
-                break
-            else:
-                raise Exception("unexpected order status when finalizing: %s" % resp["status"])
-        else:
-            raise Exception("order finalization timed out")
+### Invocation and commands
+
+invdata = threading.local()
+commands = {}
+
+class usageerr(msgerror):
+    def __init__(self):
+        self.cmd = invdata.cmd
+
+    def report(self, out):
+        out.write("%s\n" % (self.cmd.__doc__,))
+
+## User commands
+
+def cmd_reg(args):
+    "usage: acmecert reg [OUTPUT-FILE]"
+    acct = register()
+    os.umask(0o077)
+    with maybeopen(args[1] if len(args) > 1 else "-", "w") as fp:
+        acct.write(fp)
+commands["reg"] = cmd_reg
+
+def cmd_validate_acct(args):
+    "usage: acmecert validate-acct ACCOUNT-FILE"
+    if len(args) < 2: raise usageerr()
+    with maybeopen(args[1], "r") as fp:
+        account.read(fp).validate()
+commands["validate-acct"] = cmd_validate_acct
+
+def cmd_acct_info(args):
+    "usage: acmecert acct-info ACCOUNT-FILE"
+    if len(args) < 2: raise usageerr()
+    with maybeopen(args[1], "r") as fp:
+        pprint.pprint(account.read(fp).getinfo())
+commands["acct-info"] = cmd_acct_info
+
+def cmd_order(args):
+    "usage: acmecert order ACCOUNT-FILE CSR [OUTPUT-FILE]"
+    if len(args) < 3: raise usageerr()
+    with maybeopen(args[1], "r") as fp:
+        acct = account.read(fp)
+    with maybeopen(args[2], "r") as fp:
+        csr = signreq.read(fp)
+    order = mkorder(acct, csr)
+    with maybeopen(args[3] if len(args) > 3 else "-", "w") as fp:
+        fp.write("%s\n" % (order["acmecert.location"]))
+commands["order"] = cmd_order
+
+def cmd_http_auth(args):
+    "usage: acmecert http-auth ACCOUNT-FILE HTTP-CONFIG {ORDER-ID|ORDER-FILE}"
+    if len(args) < 4: raise usageerr()
+    with maybeopen(args[1], "r") as fp:
+        acct = account.read(fp)
+    with maybeopen(args[2], "r") as fp:
+        htconf = htconfig.read(fp)
+    if "://" in args[3]:
+        orderid = args[3]
     else:
-        raise Exception("unexpected order state when finalizing: %s" % (order["status"],))
-    with req(order["certificate"]) as resp:
-        return resp.read().decode("us-ascii")
+        with maybeopen(args[3], "r") as fp:
+            orderid = fp.readline().strip()
+    authorder(acct, htconf, orderid)
+commands["http-auth"] = cmd_http_auth
+
+def cmd_get(args):
+    "usage: acmecert get ACCOUNT-FILE CSR {ORDER-ID|ORDER-FILE}"
+    if len(args) < 4: raise usageerr()
+    with maybeopen(args[1], "r") as fp:
+        acct = account.read(fp)
+    with maybeopen(args[2], "r") as fp:
+        csr = signreq.read(fp)
+    if "://" in args[3]:
+        orderid = args[3]
+    else:
+        with maybeopen(args[3], "r") as fp:
+            orderid = fp.readline().strip()
+    sys.stdout.write(finalize(acct, csr, orderid))
+commands["get"] = cmd_get
+
+def cmd_http_order(args):
+    "usage: acmecert http-order ACCOUNT-FILE CSR HTTP-CONFIG [OUTPUT-FILE]"
+    if len(args) < 4: raise usageerr()
+    with maybeopen(args[1], "r") as fp:
+        acct = account.read(fp)
+    with maybeopen(args[2], "r") as fp:
+        csr = signreq.read(fp)
+    with maybeopen(args[3], "r") as fp:
+        htconf = htconfig.read(fp)
+    orderid = mkorder(acct, csr)["acmecert.location"]
+    authorder(acct, htconf, orderid)
+    with maybeopen(args[4] if len(args) > 4 else "-", "w") as fp:
+        fp.write(finalize(acct, csr, orderid))
+commands["http-order"] = cmd_http_order
+
+def cmd_check_cert(args):
+    "usage: acmecert check-cert CERT-FILE TIME-SPEC"
+    if len(args) < 3: raise usageerr()
+    with maybeopen(args[1], "r") as fp:
+        crt = certificate.read(fp)
+    sys.exit(1 if crt.expiring(args[2]) else 0)
+commands["check-cert"] = cmd_check_cert
+
+def cmd_directory(args):
+    "usage: acmecert directory"
+    pprint.pprint(directory())
+commands["directory"] = cmd_directory
+
+## Main invocation
 
 def usage(out):
-    out.write("usage: acmecert [-h] [-D SERVICE]\n")
+    out.write("usage: acmecert [-D SERVICE] COMMAND [ARGS...]\n")
+    out.write("       acmecert -h [COMMAND]\n")
+    buf =     "       COMMAND is any of: "
+    f = True
+    for cmd in commands:
+        if len(buf) + len(cmd) > 70:
+            out.write("%s\n" % (buf,))
+            buf =     "           "
+            f = True
+        if not f:
+            buf += ", "
+        buf += cmd
+        f = False
+    if not f:
+        out.write("%s\n" % (buf,))
 
 def main(argv):
     global service
     opts, args = getopt.getopt(argv[1:], "hD:")
     for o, a in opts:
         if o == "-h":
-            usage(sys.stdout)
+            if len(args) > 0:
+                cmd = commands.get(args[0])
+                if cmd is None:
+                    sys.stderr.write("acmecert: unknown command: %s\n" % (args[0],))
+                    sys.exit(1)
+                sys.stdout.write("%s\n" % (cmd.__doc__,))
+            else:
+                usage(sys.stdout)
             sys.exit(0)
         elif o == "-D":
             service = a
     if len(args) < 1:
         usage(sys.stderr)
         sys.exit(1)
-    if args[0] == "reg":
-        register().write(sys.stdout)
-    elif args[0] == "validate-acct":
-        with open(args[1], "r") as fp:
-            account.read(fp).validate()
-    elif args[0] == "acctinfo":
-        with open(args[1], "r") as fp:
-            pprint.pprint(account.read(fp).getinfo())
-    elif args[0] == "order":
-        with open(args[1], "r") as fp:
-            acct = account.read(fp)
-        with open(args[2], "r") as fp:
-            csr = signreq.read(fp)
-        order = mkorder(acct, csr)
-        with open(args[3], "w") as fp:
-            fp.write("%s\n" % (order["acmecert.location"]))
-    elif args[0] == "http-auth":
-        with open(args[1], "r") as fp:
-            acct = account.read(fp)
-        with open(args[2], "r") as fp:
-            htconf = htconfig.read(fp)
-        with open(args[3], "r") as fp:
-            orderid = fp.readline().strip()
-        authorder(acct, htconf, orderid)
-    elif args[0] == "get":
-        with open(args[1], "r") as fp:
-            acct = account.read(fp)
-        with open(args[2], "r") as fp:
-            csr = signreq.read(fp)
-        with open(args[3], "r") as fp:
-            orderid = fp.readline().strip()
-        sys.stdout.write(finalize(acct, csr, orderid))
-    elif args[0] == "http-order":
-        with open(args[1], "r") as fp:
-            acct = account.read(fp)
-        with open(args[2], "r") as fp:
-            csr = signreq.read(fp)
-        with open(args[3], "r") as fp:
-            htconf = htconfig.read(fp)
-        orderid = mkorder(acct, csr)["acmecert.location"]
-        authorder(acct, htconf, orderid)
-        sys.stdout.write(finalize(acct, csr, orderid))
-    elif args[0] == "check-cert":
-        with open(args[1], "r") as fp:
-            crt = certificate.read(fp)
-        sys.exit(1 if crt.expiring(args[2]) else 0)
-    elif args[0] == "directory":
-        pprint.pprint(directory())
-    else:
+    cmd = commands.get(args[0])
+    if cmd is None:
         sys.stderr.write("acmecert: unknown command: %s\n" % (args[0],))
         usage(sys.stderr)
         sys.exit(1)
+    try:
+        try:
+            invdata.cmd = cmd
+            cmd(args)
+        finally:
+            invdata.cmd = None
+    except msgerror as exc:
+        exc.report(sys.stderr)
+        sys.exit(1)
 
 if __name__ == "__main__":
     try: