Fix sshsocket input buffering bugs.
[pdm.git] / pdm / sshsock.py
1 import sys, os
2 import subprocess, socket, fcntl, select
3
class sshsocket(object):
    """Socket-like object that tunnels a Unix-domain stream socket over ssh.

    Spawns ``ssh HOST python3 -m pdm.sshsock PATH``, then talks to the remote
    relay through the ssh process's stdin (outgoing) and stdout (incoming).
    Only the part of the socket API implemented below is available:
    ``send``, ``recv``, ``fileno`` and ``close``.
    """

    def __init__(self, host, path, user = None, port = None):
        """Start ssh to *host* and complete the relay handshake for *path*.

        :param host: ssh destination host.
        :param path: path of the Unix socket on the remote host. It is passed
            through the remote shell unquoted (see ``_sshargs``).
        :param user: optional remote login name, passed as ``ssh -l``.
        :param port: optional ssh port, passed as ``ssh -p``.
        :raises socket.error: if the relay reports a connect failure or the
            reply is not the expected ``SSOCK`` handshake.
        """
        # Set these before anything can fail, so that close()/__del__ work
        # even if Popen raises.
        self.proc = None
        self.inbuf = bytearray()
        self.proc = subprocess.Popen(self._sshargs(host, path, user, port),
                                     stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                                     close_fds=True)
        fd = self.proc.stdout.fileno()
        fcntl.fcntl(fd, fcntl.F_SETFL, fcntl.fcntl(fd, fcntl.F_GETFL) | os.O_NONBLOCK)
        try:
            self._handshake(host)
        except BaseException:
            # Do not leave the ssh child running until garbage collection.
            self.close()
            raise

    @staticmethod
    def _sshargs(host, path, user=None, port=None):
        """Return the ssh argv that launches the remote relay for *path*."""
        args = ["ssh"]
        if user is not None:
            # OpenSSH takes the remote login name via -l; it has no -u option.
            args += ["-l", str(user)]
        if port is not None:
            args += ["-p", str(int(port))]
        args += [host]
        # NOTE(review): ssh hands the remote command to the remote user's
        # shell, so a path containing spaces or shell metacharacters is
        # reinterpreted there. It is left unquoted so that "~/..." paths keep
        # expanding; confirm callers never pass untrusted paths.
        args += ["python3", "-m", "pdm.sshsock", path]
        return args

    def _recvexact(self, n):
        """Read exactly *n* bytes, or fewer only if EOF is reached first."""
        data = b""
        while len(data) < n:
            chunk = self.recv(n - len(data))
            if not chunk:
                break
            data += chunk
        return data

    def _recvline(self):
        """Read up to and including a newline.

        :returns: ``(line, complete)``. *line* excludes the newline.
            *complete* is False if EOF arrived before a newline.
        """
        line = bytearray()
        while True:
            ch = self.recv(1)
            if ch in (b"", b"\n"):
                return bytes(line), ch == b"\n"
            line += ch

    def _handshake(self, host):
        """Consume the relay's ``SSOCK+`` / ``SSOCK-<message>`` status line.

        :raises socket.error: on an error reply, an unexpected reply or EOF.
        """
        head = self._recvexact(5)
        if head != b"SSOCK":
            raise socket.error("unexpected reply from %s: %r" % (host, head))
        head = self._recvexact(1)
        if head == b"+":
            _, complete = self._recvline()
            if not complete:
                raise socket.error("unexpected EOF in SSH socket stream")
        elif head == b"-":
            msg, _ = self._recvline()
            # The message is produced remotely; never fail on undecodable bytes.
            raise socket.error(msg.decode("utf-8", "replace"))
        else:
            raise socket.error("unexpected reply from %s: %r" % (host, head))

    def close(self):
        """Close both pipes and reap the ssh process. Safe to call repeatedly."""
        proc = getattr(self, "proc", None)
        if proc is None:
            return
        self.proc = None
        try:
            # EOF on the relay's stdin makes it leave its copy loop and exit.
            proc.stdin.close()
        except BrokenPipeError:
            # ssh is already gone; nothing buffered could be delivered anyway.
            pass
        finally:
            proc.stdout.close()
            proc.wait()

    def send(self, data, flags = 0):
        """Write all of *data* to the tunnel and return ``len(data)``.

        *flags* is accepted for socket API compatibility and is ignored.
        """
        self.proc.stdin.write(data)
        self.proc.stdin.flush()
        return len(data)

    def recv(self, buflen, flags = 0):
        """Return up to *buflen* bytes from the tunnel, or ``b""`` at EOF.

        Blocks until data or EOF is available. If *flags* contains
        ``socket.MSG_DONTWAIT``, it raises ``BlockingIOError`` instead of
        blocking when nothing is buffered or readable.
        """
        fd = self.proc.stdout.fileno()
        while not self.inbuf:
            try:
                # Read ahead so that small recv()s are served from inbuf.
                data = os.read(fd, max(4096, buflen))
            except BlockingIOError:
                if flags & socket.MSG_DONTWAIT:
                    raise
                select.select([fd], [], [])
                continue
            if not data:
                # EOF: ssh closed its stdout. Report it like a socket would,
                # instead of polling the dead fd forever.
                return b""
            self.inbuf.extend(data)
        data = bytes(self.inbuf[:buflen])
        del self.inbuf[:buflen]
        return data

    def fileno(self):
        """Return the fd that incoming data arrives on, for use with select().

        NOTE: bytes that recv() already read ahead into ``inbuf`` do not make
        this fd readable. select() on it may therefore block even though
        recv() would return data immediately.
        """
        return self.proc.stdout.fileno()

    def __del__(self):
        self.close()
72
def cli():
    """Relay bytes between stdin/stdout and the Unix socket at ``sys.argv[1]``.

    This is the remote half of ``sshsocket``. It first reports the connect
    result on its own line, either ``SSOCK+`` or ``SSOCK-connect: <error>``.
    It then copies stdin to the socket and the socket to stdout.

    When stdin hits EOF, it finishes sending what it already read from stdin
    and exits. When the socket hits EOF, it finishes writing what it already
    read from the socket and exits.
    """
    infd = sys.stdin.fileno()
    outfd = sys.stdout.fileno()
    fcntl.fcntl(infd, fcntl.F_SETFL, fcntl.fcntl(infd, fcntl.F_GETFL) | os.O_NONBLOCK)
    sk = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        try:
            sk.connect(sys.argv[1])
        except socket.error as err:
            sys.stdout.write("SSOCK-connect: %s\n" % err)
            sys.stdout.flush()
            return
        sys.stdout.write("SSOCK+\n")
        sys.stdout.flush()
        tosock = bytearray()   # read from stdin, not yet sent to the socket
        toout = bytearray()    # read from the socket, not yet written to stdout
        ineof = skeof = False
        while True:
            if ineof and not tosock:
                # Client hung up and everything it sent reached the socket.
                break
            if skeof and not toout:
                # Socket peer hung up and everything it sent reached stdout.
                break
            rlist = []
            if not ineof:
                rlist.append(infd)
            if not skeof:
                rlist.append(sk)
            wlist = []
            if tosock and not skeof:
                wlist.append(sk)
            if toout:
                wlist.append(outfd)
            rready, wready, _ = select.select(rlist, wlist, [])

            if sk in rready:
                try:
                    data = sk.recv(65536)
                except ConnectionResetError:
                    data = b""
                if data:
                    toout += data
                else:
                    skeof = True

            if infd in rready:
                try:
                    data = os.read(infd, 65536)
                except BlockingIOError:
                    data = None  # spurious wakeup; retry after the next select()
                if data == b"":
                    ineof = True
                elif data:
                    # Append: a previous partial send may still be pending.
                    tosock += data

            if sk in wready:
                try:
                    sent = sk.send(tosock)
                except (BrokenPipeError, ConnectionResetError):
                    # Peer is gone; nothing more can be delivered to it.
                    skeof = True
                    tosock.clear()
                else:
                    del tosock[:sent]

            if outfd in wready:
                # stdout may share stdin's (non-blocking) file description,
                # so tolerate partial writes and EAGAIN.
                try:
                    written = os.write(outfd, toout)
                except BlockingIOError:
                    written = 0
                except BrokenPipeError:
                    # ssh has gone away; nobody is left to read our output.
                    break
                del toout[:written]
    finally:
        sk.close()
113
if __name__ == "__main__":
    # Remote entry point: sshsocket runs `python3 -m pdm.sshsock PATH` via ssh.
    cli()