blob: 0887a765ba51d39ecf84ee3ba26db977cdbea9d7 [file] [log] [blame]
Stefan Kobzab1899332015-12-23 17:00:10 +01001# Copyright (c) 2015 Cisco and/or its affiliates.
2# Licensed under the Apache License, Version 2.0 (the "License");
3# you may not use this file except in compliance with the License.
4# You may obtain a copy of the License at:
5#
6# http://www.apache.org/licenses/LICENSE-2.0
7#
8# Unless required by applicable law or agreed to in writing, software
9# distributed under the License is distributed on an "AS IS" BASIS,
10# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11# See the License for the specific language governing permissions and
12# limitations under the License.
13import paramiko
14from scp import SCPClient
15from time import time
16from robot.api import logger
17
# Only the exec_cmd() convenience wrapper is exported on star-import.
__all__ = ["exec_cmd"]

# TODO: Connections are already reused through the class-level cache in
# SSH.connect(); consider adding explicit cleanup of cached connections.
# TODO: Support private-key authentication (SSH.connect() currently passes
# only a username and password).
22
class SSH(object):
    """Thin wrapper around paramiko.SSHClient for running commands on nodes.

    A node is a dict providing at least 'host', 'port', 'username' and
    'password'. Connected clients are cached at class level, so several
    SSH() instances talking to the same node share one underlying client.
    """

    # Maximum number of bytes requested from the channel per recv() call.
    __MAX_RECV_BUF = 10*1024*1024
    # Class-level cache shared by all instances: node key -> connected
    # paramiko.SSHClient.
    __existing_connections = {}

    def __init__(self):
        self._ssh = self._new_client()
        self._hostname = None

    @staticmethod
    def _new_client():
        """Return a fresh, unconnected SSHClient that auto-accepts host keys."""
        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        return client

    def _node_hash(self, node):
        """Return the connection-cache key for node.

        A plain tuple is used instead of hash(frozenset(...)). A frozenset
        loses the host/port distinction, and a bare hash() can collide.
        Either problem would silently hand one node's connection to another.
        The username is part of the key so that connections authenticated
        as different users are never shared.
        """
        return (node['host'], node['port'], node['username'])

    def connect(self, node):
        """Connect to node prior to running exec_command or scp.

        If there already is a live connection to the node, this method
        reuses it. A cached connection whose transport is no longer active
        is discarded and re-established.
        """
        self._hostname = node['host']
        node_key = self._node_hash(node)

        cached = self.__existing_connections.get(node_key)
        if cached is not None:
            transport = cached.get_transport()
            if transport is not None and transport.is_active():
                self._ssh = cached
                return
            # The cached connection died (e.g. the node went away); drop it.
            del self.__existing_connections[node_key]

        start = time()
        # Always use a new client. Reconnecting self._ssh could re-point a
        # client that is already cached for a different node.
        client = self._new_client()
        client.connect(node['host'], port=int(node['port']),
                       username=node['username'], password=node['password'])
        self.__existing_connections[node_key] = client
        self._ssh = client
        logger.trace('connect took {} seconds'.format(time() - start))

    def _read_all(self, read):
        """Call read(max_bytes) until it returns no data; return the joined data.

        Chunks are collected in a list and joined once. This avoids quadratic
        string concatenation. On Python 2 the result is str; on Python 3 it
        is bytes.
        """
        chunks = []
        while True:
            buf = read(self.__MAX_RECV_BUF)
            if not buf:
                break
            chunks.append(buf)
        return b''.join(chunks)

    def exec_command(self, cmd, timeout=10):
        """Execute SSH command on a new channel on the connected Node.

        :param cmd: Command line to run on the node.
        :param timeout: Channel timeout in seconds (int or float), applied to
            each blocking read; None disables the timeout.
        :return: (return_code, stdout, stderr), with stdout/stderr being the
            raw data read from the channel.
        """
        start = time()
        chan = self._ssh.get_transport().open_session()
        try:
            if timeout is not None:
                # float(), not int(): int() truncates e.g. 0.5 to 0, and a
                # zero timeout switches the channel to non-blocking mode.
                chan.settimeout(float(timeout))
            chan.exec_command(cmd)
            end = time()
            logger.trace('exec_command "{0}" on {1} took {2} seconds'.format(
                cmd, self._hostname, end-start))

            # NOTE(review): stdout is drained completely before stderr is
            # read. A command producing very large stderr output may stall if
            # the channel's receive window fills up. Confirm against paramiko.
            stdout = self._read_all(chan.recv)
            stderr = self._read_all(chan.recv_stderr)

            return_code = chan.recv_exit_status()
            logger.trace('chan_recv/_stderr took {} seconds'.format(
                time()-end))
        finally:
            # Release the channel even if a read timed out or failed.
            chan.close()

        return (return_code, stdout, stderr)

    def scp(self, local_path, remote_path):
        """Copy files from local_path to remote_path.

        connect() method has to be called first!
        """
        logger.trace('SCP {0} to {1}:{2}'.format(
            local_path, self._hostname, remote_path))
        # SCPClient takes a paramiko transport as its only argument
        scp = SCPClient(self._ssh.get_transport())
        start = time()
        try:
            scp.put(local_path, remote_path)
        finally:
            # Close the SCP session even when the transfer fails.
            scp.close()
        end = time()
        logger.trace('SCP took {0} seconds'.format(end-start))
100
def exec_cmd(node, cmd, timeout=None):
    """Convenience function to ssh/exec/return rc, stdout & stderr.

    :param node: Node dict accepted by SSH.connect().
    :param cmd: Command to execute on the node; must be non-empty.
    :param timeout: Passed through to SSH.exec_command(); None means no
        timeout.
    :return: (rc, stdout, stderr), or None if connecting to the node or
        executing the command failed (the error is logged).
    :raises TypeError: If node or cmd is None.
    :raises ValueError: If cmd is empty.
    """
    if node is None:
        raise TypeError('Node parameter is None')
    if cmd is None:
        raise TypeError('Command parameter is None')
    if len(cmd) == 0:
        raise ValueError('Empty command parameter')

    ssh = SSH()
    try:
        ssh.connect(node)
    except Exception as e:
        # Format rather than concatenate: str + Exception raises TypeError,
        # which used to mask the real connection error.
        logger.error('Failed to connect to node: {0}'.format(e))
        return None

    try:
        (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
    except Exception as e:
        logger.error(e)
        return None

    return (ret_code, stdout, stderr)
127