Throw more informative error classes from perf.
[pdm.git] / pdm / sshsock.py
index 3a64122..c0ac300 100644 (file)
@@ -11,7 +11,32 @@ class sshsocket(object):
         args += [host]
         args += ["python3", "-m", "pdm.sshsock", path]
         self.proc = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, close_fds=True)
+        self.inbuf = bytearray()
         fcntl.fcntl(self.proc.stdout, fcntl.F_SETFL, fcntl.fcntl(self.proc.stdout, fcntl.F_GETFL) | os.O_NONBLOCK)
+        head = self.recv(5)
+        if head != b"SSOCK":
+            raise socket.error("unexpected reply from %s: %r" % (host, head))
+        head = self.recv(1)
+        if head == b"+":
+            buf = b""
+            while True:
+                r = self.recv(1)
+                if r == b"":
+                    raise socket.error("unexpected EOF in SSH socket stream")
+                elif r == b"\n":
+                    break
+                buf += r
+            return
+        elif head == b"-":
+            buf = b""
+            while True:
+                r = self.recv(1)
+                if r in {b"\n", b""}:
+                    break
+                buf += r
+            raise socket.error(buf.decode("utf-8"))
+        else:
+            raise socket.error("unexpected reply from %s: %r" % (host, head))
 
     def close(self):
         if self.proc is not None:
@@ -22,12 +47,22 @@ class sshsocket(object):
 
     def send(self, data, flags = 0):
         self.proc.stdin.write(data)
+        self.proc.stdin.flush()
         return len(data)
 
     def recv(self, buflen, flags = 0):
-        if (flags & socket.MSG_DONTWAIT) == 0:
-            select.select([self.proc.stdout], [], [])
-        return self.proc.stdout.read(buflen)
+        while len(self.inbuf) == 0:
+            try:
+                rv = os.read(self.proc.stdout.fileno(), max(4096, buflen))
+            except BlockingIOError:
+                if flags & socket.MSG_DONTWAIT:
+                    raise
+                select.select([self.proc.stdout], [], [])
+            else:
+                self.inbuf.extend(rv)
+        rv = bytes(self.inbuf[:buflen])
+        self.inbuf[:buflen] = b""
+        return rv
 
     def fileno(self):
         return self.proc.stdout.fileno()
@@ -39,7 +74,14 @@ def cli():
     fcntl.fcntl(sys.stdin.buffer, fcntl.F_SETFL, fcntl.fcntl(sys.stdin.buffer, fcntl.F_GETFL) | os.O_NONBLOCK)
     sk = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
     try:
-        sk.connect(sys.argv[1])
+        try:
+            sk.connect(sys.argv[1])
+        except socket.error as err:
+            sys.stdout.write("SSOCK-connect: %s\n" % err)
+            sys.stdout.flush()
+            return
+        sys.stdout.write("SSOCK+\n")
+        sys.stdout.flush()
         buf1 = b""
         buf2 = b""
         while True: