blob: cca0c2c07352bec37b64021394e4ac1850b7985f [file] [log] [blame]
##############################################################################
# Copyright 2018 European Software Marketing Ltd.
# ===================================================================
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#         http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under
# the License
##############################################################################
# vnftest comment: this is a modified copy of rally/rally/common/sshutils.py
15
"""High level ssh library.

Usage examples:

Execute command and get output:

    ssh = sshclient.SSH("root", "example.com", port=33)
    status, stdout, stderr = ssh.execute("ps ax")
    if status:
        raise Exception("Command failed with non-zero status.")
    print(stdout.splitlines())

Execute command with huge output:

    class PseudoFile(io.RawIOBase):
        def write(self, chunk):
            if "error" in chunk:
                email_admin(chunk)

    ssh = SSH("root", "example.com")
    with PseudoFile() as p:
        ssh.run("tail -f /var/log/syslog", stdout=p, timeout=False)

Execute local script on remote side:

    ssh = sshclient.SSH("user", "example.com")

    # open() does not expand "~", so expand it explicitly
    with open(os.path.expanduser("~/myscript.sh"), "r") as stdin_file:
        status, out, err = ssh.execute('/bin/sh -s "arg1" "arg2"',
                                       stdin=stdin_file)

Upload file:

    ssh = SSH("user", "example.com")
    # use rb for binary files
    with open("/store/file.gz", "rb") as stdin_file:
        ssh.run("cat > ~/upload/file.gz", stdin=stdin_file)

Eventlet:

    eventlet.monkey_patch(select=True, time=True)
    or
    eventlet.monkey_patch()
    or
    sshclient = eventlet.import_patched("vnftest.ssh")

"""
63from __future__ import absolute_import
64import os
65import io
66import select
67import socket
68import time
69import re
70
71import logging
72
73import paramiko
74from chainmap import ChainMap
75from oslo_utils import encodeutils
76from scp import SCPClient
77import six
78from vnftest.common.utils import try_int
79
80
def convert_key_to_str(key):
    """Return the private-key text of a paramiko RSA/DSS key object.

    Any other value (for example a key that is already a string, or None)
    is returned unchanged.

    :param key: a ``paramiko.RSAKey``/``paramiko.DSSKey`` or any other value.
    :returns: the serialized private key for key objects, else ``key``.
    """
    if not isinstance(key, (paramiko.RSAKey, paramiko.DSSKey)):
        return key
    # six.moves.StringIO instead of io.StringIO: on Python 2 paramiko writes
    # native (byte) strings, which io.StringIO rejects with a TypeError.
    buf = six.moves.StringIO()
    key.write_private_key(buf)
    return buf.getvalue()
87
88
class SSHError(Exception):
    """Base class for the errors raised by the SSH helpers in this module."""
91
92
class SSHTimeout(SSHError):
    """Raised when a remote command, or waiting for the host, runs too long."""
95
96
class SSH(object):
    """Represent ssh connection."""

    SSH_PORT = paramiko.config.SSH_PORT

    @staticmethod
    def gen_keys(key_filename, bit_count=2048):
        """Generate an RSA key pair on the local filesystem.

        :param key_filename: path of the private key file to write; the
            public key is written to ``<key_filename>.pub`` as
            ``<key type> <base64>``.
        :param bit_count: RSA key size in bits.
        """
        rsa_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None)
        rsa_key.write_private_key_file(key_filename)
        print("Writing %s ..." % key_filename)
        with open('.'.join([key_filename, "pub"]), "w") as pubkey_file:
            pubkey_file.write(rsa_key.get_name())
            pubkey_file.write(' ')
            pubkey_file.write(rsa_key.get_base64())
            pubkey_file.write('\n')

    @staticmethod
    def get_class():
        # must return static class name, anything else refers to the calling class
        # i.e. the subclass, not the superclass
        return SSH

    def __init__(self, user, host, port=None, pkey=None,
                 key_filename=None, password=None, name=None):
        """Initialize SSH client.

        No connection is opened here; it is created on first use.

        :param user: ssh username
        :param host: hostname or ip address of remote ssh server
        :param port: remote ssh port; passed through ``try_int`` with
            ``SSH_PORT`` as the fallback
        :param pkey: RSA or DSS private key string or file object
        :param key_filename: private key filename
        :param password: password
        :param name: optional label appended to the logger name
        """
        self.name = name
        if name:
            self.log = logging.getLogger(__name__ + '.' + self.name)
        else:
            self.log = logging.getLogger(__name__)

        self.user = user
        self.host = host
        # everybody wants to debug this in the caller, do it here instead
        self.log.debug("user:%s host:%s", user, host)

        # we may get text port from YAML, convert to int
        self.port = try_int(port, self.SSH_PORT)
        self.pkey = self._get_pkey(pkey) if pkey else None
        self.password = password
        self.key_filename = key_filename
        # False means "no client yet"; see is_connected
        self._client = False
        # paramiko loglevel debug will output ssh protocol debug
        # we don't ever really want that unless we are debugging paramiko
        # ssh issues
        if os.environ.get("PARAMIKO_DEBUG", "").lower() == "true":
            logging.getLogger("paramiko").setLevel(logging.DEBUG)
        else:
            logging.getLogger("paramiko").setLevel(logging.WARN)

    @classmethod
    def args_from_node(cls, node, overrides=None, defaults=None):
        """Build constructor kwargs from a node dict.

        Values are looked up in ``overrides``, then ``node``, then
        ``defaults``; ``user`` and ``ip`` are required.
        """
        if overrides is None:
            overrides = {}
        if defaults is None:
            defaults = {}
        params = ChainMap(overrides, node, defaults)
        return {
            'user': params['user'],
            'host': params['ip'],
            'port': params.get('ssh_port', cls.SSH_PORT),
            'pkey': params.get('pkey'),
            'key_filename': params.get('key_filename'),
            'password': params.get('password'),
            'name': params.get('name'),
        }

    @classmethod
    def from_node(cls, node, overrides=None, defaults=None):
        """Construct an instance from a node dict; see args_from_node."""
        return cls(**cls.args_from_node(node, overrides, defaults))

    def _get_pkey(self, key):
        """Parse a private key given as a string or a readable file object.

        RSA and then DSS parsing is attempted.  The stream is rewound before
        each attempt, because a failed parse consumes it and would otherwise
        leave nothing for the next key class to read.

        :raises SSHError: if no key class accepts the key.
        """
        if isinstance(key, six.string_types):
            key = six.moves.StringIO(key)
        try:
            start = key.tell()
        except (AttributeError, IOError, OSError, ValueError):
            # not seekable: only the first attempt can see the data
            start = None
        errors = []
        for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
            if start is not None:
                key.seek(start)
            try:
                return key_class.from_private_key(key)
            except paramiko.SSHException as e:
                errors.append(e)
        raise SSHError("Invalid pkey: %s" % (errors))

    @property
    def is_connected(self):
        """True once _get_client() has stored a connected client."""
        return bool(self._client)

    def _get_client(self):
        """Return the connected paramiko client, connecting on first use.

        :raises SSHError: wrapping any exception raised while connecting.
        """
        if self.is_connected:
            return self._client
        client = None
        try:
            client = paramiko.SSHClient()
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            client.connect(self.host, username=self.user,
                           port=self.port, pkey=self.pkey,
                           key_filename=self.key_filename,
                           password=self.password,
                           allow_agent=False, look_for_keys=False,
                           timeout=1)
        except Exception as e:
            # release whatever the failed connect left open
            if client is not None:
                try:
                    client.close()
                except Exception:
                    self.log.debug("Error closing failed ssh client",
                                   exc_info=True)
            self._client = False
            message = ("Exception %(exception_type)s was raised "
                       "during connect. Exception value is: %(exception)r")
            raise SSHError(message % {"exception": e,
                                      "exception_type": type(e)})
        self._client = client
        return self._client

    def _make_dict(self):
        """Return the constructor kwargs that recreate this object."""
        return {
            'user': self.user,
            'host': self.host,
            'port': self.port,
            'pkey': self.pkey,
            'key_filename': self.key_filename,
            'password': self.password,
            'name': self.name,
        }

    def copy(self):
        """Return a new, unconnected object with the same settings."""
        return self.get_class()(**self._make_dict())

    def close(self):
        """Close the underlying client, if any."""
        if self._client:
            self._client.close()
            self._client = False

    def run(self, cmd, stdin=None, stdout=None, stderr=None,
            raise_on_error=True, timeout=3600,
            keep_stdin_open=False, pty=False):
        """Execute specified command on the server.

        :param cmd: Command to be executed.
        :type cmd: str
        :param stdin: Open file or string to pass to stdin.
        :param stdout: Open file to connect to stdout.
        :param stderr: Open file to connect to stderr.
        :param raise_on_error: If False then the exit code is returned. If
                               True then an exception is raised on a
                               non-zero exit code.
        :param timeout: Timeout in seconds for command execution.
                        Default 1 hour. No timeout if set to 0.
        :param keep_stdin_open: don't close stdin on empty reads
        :type keep_stdin_open: bool
        :param pty: Request a pseudo terminal for this connection.
                    This allows passing control characters.
                    Default False.
        :type pty: bool
        """

        client = self._get_client()

        if isinstance(stdin, six.string_types):
            stdin = six.moves.StringIO(stdin)

        return self._run(client, cmd, stdin=stdin, stdout=stdout,
                         stderr=stderr, raise_on_error=raise_on_error,
                         timeout=timeout,
                         keep_stdin_open=keep_stdin_open, pty=pty)

    def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
             raise_on_error=True, timeout=3600,
             keep_stdin_open=False, pty=False):
        """Run ``cmd`` on a new channel of ``client``; see :meth:`run`.

        The channel is closed on every exit path, including timeouts and
        errors.

        :returns: the remote exit status.
        :raises SSHTimeout: if ``timeout`` is set and exceeded.
        :raises SSHError: on a socket error, or a non-zero exit status when
            ``raise_on_error`` is true.
        """
        import codecs

        transport = client.get_transport()
        session = transport.open_session()
        try:
            if pty:
                session.get_pty()
            session.exec_command(cmd)
            start_time = time.time()

            # encode on transmit, decode on receive.  Incremental decoders
            # keep a multi-byte UTF-8 character split across two 4096-byte
            # reads for the next read instead of failing on the first half.
            out_decoder = codecs.getincrementaldecoder('utf-8')()
            err_decoder = codecs.getincrementaldecoder('utf-8')()
            data_to_send = encodeutils.safe_encode("", incoming='utf-8')
            stderr_data = None

            # If we have data to be sent to stdin then `select' should also
            # check for stdin availability.
            if stdin and not stdin.closed:
                writes = [session]
            else:
                writes = []

            while True:
                # Block until data can be read/write.
                r, w, e = select.select([session], writes, [session], 1)

                # Checked before reading: the read branches ``continue``, so
                # a command that keeps producing output must still time out.
                # Once the command has exited, buffered output is drained.
                if (timeout and not session.exit_status_ready() and
                        time.time() - start_time > timeout):
                    args = {"cmd": cmd, "host": self.host}
                    raise SSHTimeout("Timeout executing command "
                                     "'%(cmd)s' on host %(host)s" % args)

                if session.recv_ready():
                    data = out_decoder.decode(session.recv(4096))
                    self.log.debug("stdout: %r", data)
                    if stdout is not None:
                        stdout.write(data)
                    continue

                if session.recv_stderr_ready():
                    stderr_data = err_decoder.decode(
                        session.recv_stderr(4096))
                    self.log.debug("stderr: %r", stderr_data)
                    if stderr is not None:
                        stderr.write(stderr_data)
                    continue

                if session.send_ready():
                    if stdin is not None and not stdin.closed:
                        if not data_to_send:
                            stdin_txt = stdin.read(4096)
                            if stdin_txt is None:
                                stdin_txt = ''
                            data_to_send = encodeutils.safe_encode(
                                stdin_txt, incoming='utf-8')
                            if not data_to_send:
                                # we may need to keep stdin open
                                if not keep_stdin_open:
                                    stdin.close()
                                    session.shutdown_write()
                                    writes = []
                        if data_to_send:
                            sent_bytes = session.send(data_to_send)
                            data_to_send = data_to_send[sent_bytes:]

                if session.exit_status_ready():
                    break

                if e:
                    raise SSHError("Socket error.")

            # Flush the decoders; strict decoding raises here if the output
            # ended in the middle of a character.
            for decoder, sink in ((out_decoder, stdout),
                                  (err_decoder, stderr)):
                tail = decoder.decode(b'', final=True)
                if tail and sink is not None:
                    sink.write(tail)

            exit_status = session.recv_exit_status()
            if exit_status != 0 and raise_on_error:
                fmt = "Command '%(cmd)s' failed with exit_status %(status)d."
                details = fmt % {"cmd": cmd, "status": exit_status}
                if stderr_data:
                    details += " Last stderr data: '%s'." % stderr_data
                raise SSHError(details)
            return exit_status
        finally:
            session.close()

    def execute(self, cmd, stdin=None, timeout=3600):
        """Execute the specified command on the server.

        :param cmd: Command to be executed.
        :param stdin: Open file to be sent on process stdin.
        :param timeout: Timeout for execution of the command.

        :returns: tuple (exit_status, stdout, stderr)
        """
        stdout = six.moves.StringIO()
        stderr = six.moves.StringIO()

        exit_status = self.run(cmd, stderr=stderr,
                               stdout=stdout, stdin=stdin,
                               timeout=timeout, raise_on_error=False)
        stdout.seek(0)
        stderr.seek(0)
        return exit_status, stdout.read(), stderr.read()

    def wait(self, timeout=120, interval=1):
        """Wait until the host accepts ssh and runs ``uname``.

        :param timeout: seconds to keep retrying.
        :param interval: seconds to sleep between attempts.
        :returns: the (exit_status, stdout, stderr) tuple of ``uname``.
        :raises SSHTimeout: if the host is still unavailable after timeout.
        """
        start_time = time.time()
        while True:
            try:
                return self.execute("uname")
            except (socket.error, SSHError) as e:
                self.log.debug("Ssh is still unavailable: %r", e)
                time.sleep(interval)
            if time.time() > (start_time + timeout):
                raise SSHTimeout("Timeout waiting for '%s'" % self.host)

    def put(self, files, remote_path=b'.', recursive=False):
        """Copy local ``files`` to ``remote_path`` over SCP."""
        client = self._get_client()

        with SCPClient(client.get_transport()) as scp:
            scp.put(files, remote_path, recursive)

    def get(self, remote_path, local_path='/tmp/', recursive=True):
        """Copy ``remote_path`` to ``local_path`` over SCP."""
        client = self._get_client()

        with SCPClient(client.get_transport()) as scp:
            scp.get(remote_path, local_path, recursive)

    # keep shell running in the background, e.g. screen
    def send_command(self, command):
        """Start ``command`` with a pty without waiting for its output."""
        client = self._get_client()
        client.exec_command(command, get_pty=True)

    def _put_file_sftp(self, localpath, remotepath, mode=None):
        """Upload over SFTP; without ``mode`` copy the local permission bits."""
        client = self._get_client()

        with client.open_sftp() as sftp:
            sftp.put(localpath, remotepath)
            if mode is None:
                mode = 0o777 & os.stat(localpath).st_mode
            sftp.chmod(remotepath, mode)

    # splits an optional leading "~" or "~user/" from the rest of the path
    TILDE_EXPANSIONS_RE = re.compile("(^~[^/]*/)?(.*)")

    def _put_file_shell(self, localpath, remotepath, mode=None):
        """Upload by piping the file into ``cat`` on the remote shell.

        A leading ``~``/``~user/`` prefix stays outside the quotes so the
        remote shell still expands it.
        """
        # quote to stop word splitting
        tilde, remotepath = self.TILDE_EXPANSIONS_RE.match(remotepath).groups()
        if not tilde:
            tilde = ''
        cmd = ['cat > %s"%s"' % (tilde, remotepath)]
        if mode is not None:
            # use -- so no options
            cmd.append('chmod -- 0%o %s"%s"' % (mode, tilde, remotepath))

        with open(localpath, "rb") as localfile:
            # only chmod on successful cat
            self.run("&& ".join(cmd), stdin=localfile)

    def put_file(self, localpath, remotepath, mode=None):
        """Copy specified local file to the server.

        SFTP is tried first; on paramiko.SSHException or socket.error the
        shell-based upload is used instead.

        :param localpath: Local filename.
        :param remotepath: Remote filename.
        :param mode: Permissions to set after upload
        """
        try:
            self._put_file_sftp(localpath, remotepath, mode=mode)
        except (paramiko.SSHException, socket.error):
            self._put_file_shell(localpath, remotepath, mode=mode)

    def put_file_obj(self, file_obj, remotepath, mode=None):
        """Upload the contents of ``file_obj`` over SFTP, optionally chmod."""
        client = self._get_client()

        with client.open_sftp() as sftp:
            sftp.putfo(file_obj, remotepath)
            if mode is not None:
                sftp.chmod(remotepath, mode)

    def get_file_obj(self, remotepath, file_obj):
        """Download ``remotepath`` over SFTP into ``file_obj``."""
        client = self._get_client()

        with client.open_sftp() as sftp:
            sftp.getfo(remotepath, file_obj)
439
440
class AutoConnectSSH(SSH):
    """SSH client that connects implicitly before each remote operation.

    When ``wait`` is true, failed connection attempts are retried for up to
    ``CONNECT_TIMEOUT`` seconds. :meth:`SSH.wait` then runs once before the
    first operation.
    """

    # Retry budget (seconds) and pause between attempts used by _connect()
    # when ``wait`` is set; same values as the SSH.wait() defaults.
    CONNECT_TIMEOUT = 120
    CONNECT_INTERVAL = 1

    # always wait or we will get OpenStack SSH errors
    def __init__(self, user, host, port=None, pkey=None,
                 key_filename=None, password=None, name=None, wait=True):
        super(AutoConnectSSH, self).__init__(user, host, port, pkey,
                                             key_filename, password, name)
        self._wait = wait

    def _make_dict(self):
        """Constructor kwargs for copy(), including the ``wait`` flag."""
        data = super(AutoConnectSSH, self)._make_dict()
        data.update({
            'wait': self._wait
        })
        return data

    def _connect(self):
        """Connect if not yet connected, waiting for the host if requested.

        :raises SSHError: if connecting fails and ``wait`` is false.
        :raises SSHTimeout: if the host does not accept a connection within
            ``CONNECT_TIMEOUT`` seconds (or SSH.wait() times out).
        """
        if self.is_connected:
            return
        if not self._wait:
            self._get_client()
            return
        # Retry here: otherwise an unreachable host raises from _get_client()
        # before wait() ever gets a chance to wait for it.
        deadline = time.time() + self.CONNECT_TIMEOUT
        while True:
            try:
                self._get_client()
                break
            except SSHError as e:
                if time.time() > deadline:
                    raise SSHTimeout("Timeout waiting for '%s'" % self.host)
                self.log.debug("Ssh is still unavailable: %r", e)
                time.sleep(self.CONNECT_INTERVAL)
        # already connected, so wait() -> execute() -> _connect() returns
        # immediately instead of recursing
        self.wait()

    def drop_connection(self):
        """ Don't close anything, just force creation of a new client """
        self._client = False

    def execute(self, cmd, stdin=None, timeout=3600):
        self._connect()
        return super(AutoConnectSSH, self).execute(cmd, stdin, timeout)

    def run(self, cmd, stdin=None, stdout=None, stderr=None,
            raise_on_error=True, timeout=3600,
            keep_stdin_open=False, pty=False):
        self._connect()
        return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr,
                                               raise_on_error, timeout,
                                               keep_stdin_open, pty)

    def put(self, files, remote_path=b'.', recursive=False):
        self._connect()
        return super(AutoConnectSSH, self).put(files, remote_path, recursive)

    def get(self, remote_path, local_path='/tmp/', recursive=True):
        self._connect()
        return super(AutoConnectSSH, self).get(remote_path, local_path,
                                               recursive)

    def send_command(self, command):
        self._connect()
        return super(AutoConnectSSH, self).send_command(command)

    def put_file(self, local_path, remote_path, mode=None):
        self._connect()
        return super(AutoConnectSSH, self).put_file(local_path, remote_path,
                                                    mode)

    def put_file_obj(self, file_obj, remote_path, mode=None):
        self._connect()
        return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path,
                                                        mode)

    def get_file_obj(self, remote_path, file_obj):
        self._connect()
        return super(AutoConnectSSH, self).get_file_obj(remote_path, file_obj)

    @staticmethod
    def get_class():
        # must return static class name, anything else refers to the calling class
        # i.e. the subclass, not the superclass
        return AutoConnectSSH