Skip to content

Commit

Permalink
completer improvements
Browse files Browse the repository at this point in the history
add support for s3 and local path completion in s3 commands.
the s3 path completion includes both bucket names and prefixes.
in case --profile/--region are passed on the command line, they are used
when performing the s3 queries.
the creation of the 'help command' was moved away from the constructor
since it takes a considerable amount of time, and is not required for path
completion.
also, support completion of --profile p<TAB> (partial profile name)
  • Loading branch information
erankor committed Dec 30, 2020
1 parent 4c71115 commit 3ebb2ea
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 11 deletions.
4 changes: 2 additions & 2 deletions awscli/clidriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def main():
return rc


def create_clidriver():
session = botocore.session.Session(EnvironmentVariables)
def create_clidriver(profile=None):
session = botocore.session.Session(EnvironmentVariables, profile=profile)
_set_user_agent_for_session(session)
load_plugins(session.full_config.get('plugins', {}),
event_hooks=session.get_component('event_emitter'))
Expand Down
137 changes: 128 additions & 9 deletions awscli/completer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,53 @@
import sys
import logging
import copy
import subprocess
from botocore.exceptions import ClientError

LOG = logging.getLogger(__name__)


class Completer(object):

def __init__(self, driver=None):
S3_SCHEME = 's3://'

def __init__(self, driver=None, region=None):
if driver is not None:
self.driver = driver
else:
self.driver = awscli.clidriver.create_clidriver()
self.main_help = self.driver.create_help_command()
self.main_options = self._get_documented_completions(
self.main_help.arg_table)
self.region = region
self.main_help = None
self.s3 = None

def complete(self, cmdline, point=None):
if point is None:
point = len(cmdline)
if point is not None:
cmdline = cmdline[0:point]

args = cmdline[0:point].split()
args = cmdline.split()
current_arg = args[-1]

if (not cmdline.endswith(current_arg)
and not current_arg.startswith('-')):
# There are spaces after the arg, treat it as a new arg
current_arg = ''
args.append('')

if len(args) > 1 and args[-2] == '--profile':
return [n for n in self.driver.session.available_profiles
if n.startswith(args[-1])]

args_set = set(args)
if 's3' in args_set and not current_arg.startswith('-'):
if len(args_set & set(['cp', 'mv', 'sync'])) > 0:
return self._complete_s3_arg(current_arg, True)
elif len(args_set & set(['ls', 'presign', 'rm'])) > 0:
return self._complete_s3_arg(current_arg, False)

cmd_args = [w for w in args if not w.startswith('-')]
opts = [w for w in args if w.startswith('-')]

self.main_help = self.driver.create_help_command()
cmd_name, cmd = self._get_command(self.main_help, cmd_args)
subcmd_name, subcmd = self._get_command(cmd, cmd_args)

Expand All @@ -61,6 +84,88 @@ def _complete_command(self, command_name, command_help, current_arg, opts):
command_help.command_table, current_arg)
return []


def _complete_s3_bucket(self, current_arg):
response = self.s3.list_buckets()
result = []
for bucket in response['Buckets']:
if bucket['Name'].startswith(current_arg):
result.append(bucket['Name'] + '/')
return result

def _complete_s3_prefix(self, current_arg):
split_arg = current_arg.split('/', 1)
if len(split_arg) < 2:
return self._complete_s3_bucket(current_arg)

bucket_name, prefix = split_arg
paginator = self.s3.get_paginator('list_objects')
page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix,
Delimiter='/')

result = []
for page in page_iterator:
if 'CommonPrefixes' in page:
for item in page['CommonPrefixes']:
result.append('%s/%s' % (bucket_name, item['Prefix']))
if 'Contents' in page:
for item in page['Contents']:
result.append('%s/%s' % (bucket_name, item['Key']))

return result

def _complete_s3_prefix_loop(self, current_arg=''):
result = []
for i in range(10):
try:
cur = self._complete_s3_prefix(current_arg)
except ClientError:
break
if len(cur) == 0:
break

result = cur
if len(cur) != 1:
break
current_arg = cur[0]

return [self.S3_SCHEME + n for n in result]

def _complete_local_path(self, current_arg):
# Delegate to compgen for local file completion
cmd_args = ['bash', '-c', 'compgen -f -- %s' % current_arg]
try:
return (subprocess.check_output(cmd_args).decode('utf-8')
.splitlines())
except subprocess.CalledProcessError:
return []

def _complete_s3_arg(self, current_arg, local):
self.s3 = self.driver.session.create_client('s3',
region_name=self.region)

result = []
if current_arg.startswith(self.S3_SCHEME):
# S3 path completion
result = self._complete_s3_prefix_loop(
current_arg[len(self.S3_SCHEME):])
else:
if self.S3_SCHEME.startswith(current_arg):
# Arg is a prefix of s3 scheme - perform s3 path completion
result = self._complete_s3_prefix_loop()

if local:
# Local path completion
result += self._complete_local_path(current_arg)

if ':' in current_arg:
# Bash starts the completion from the last :, strip anything that
# precedes it
strip = current_arg.rfind(':') + 1
result = [n[strip:] for n in result]

return result

def _complete_subcommand(self, subcmd_name, subcmd_help, current_arg, opts):
if current_arg != subcmd_name and current_arg.startswith('-'):
return self._find_possible_options(current_arg, opts, subcmd_help)
Expand Down Expand Up @@ -109,7 +214,9 @@ def _get_documented_completions(self, table, startswith=None):
return names

def _find_possible_options(self, current_arg, opts, subcmd_help=None):
all_options = copy.copy(self.main_options)
main_options = self._get_documented_completions(
self.main_help.arg_table)
all_options = copy.copy(main_options)
if subcmd_help is not None:
all_options += self._get_documented_completions(
subcmd_help.arg_table)
Expand All @@ -130,7 +237,19 @@ def _find_possible_options(self, current_arg, opts, subcmd_help=None):


def complete(cmdline, point):
choices = Completer().complete(cmdline, point)
# Get the profile and region args
args = cmdline[0:point].split()

profile = None
if '--profile' in args[:-2]:
profile = args[args.index('--profile') + 1]

region = None
if '--region' in args[:-2]:
region = args[args.index('--region') + 1]

driver = awscli.clidriver.create_clidriver(profile=profile)
choices = Completer(driver, region).complete(cmdline, point)
print(' \n'.join(choices))


Expand Down
91 changes: 91 additions & 0 deletions tests/unit/test_completer.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,93 @@ def test_complete_custom_command_arguments_with_arg_already_used(self):
'--bar', '--sse'])


class TestProfileCompletion(BaseCompleterTest):
def setUp(self):
super(TestProfileCompletion, self).setUp()
commands = {
'subcommands': {},
'arguments': ['profile']
}
self.completer = Completer(
self.clidriver_creator.create_clidriver(commands,
profiles=['testprofile1', 'testprofile2', 'other']))

def test_complete_profile_empty(self):
self.assert_completion(self.completer, 's3 ls s3:// --profile ', [
'testprofile1', 'testprofile2', 'other'])

def test_complete_profile_partial_multi(self):
self.assert_completion(self.completer, 's3 ls s3:// --profile test', [
'testprofile1', 'testprofile2'])

def test_complete_profile_partial_single(self):
self.assert_completion(self.completer, 's3 ls s3:// --profile o', [
'other'])


class TestS3PathCompletion(BaseCompleterTest):
def setUp(self):
super(TestS3PathCompletion, self).setUp()
commands = {
'subcommands': {},
'arguments': []
}
clidriver = self.clidriver_creator.create_clidriver(commands)

clidriver.session.create_client.return_value.list_buckets.return_value\
= {'Buckets': [{'Name': 'mybucket'}, {'Name': 'otherbucket'}]}
clidriver.session.create_client.return_value.get_paginator.return_value\
.paginate.return_value = [{'Contents': [], 'CommonPrefixes': []}]

self.completer = Completer(clidriver)

def test_complete_path_empty(self):
self.assert_completion(self.completer, 's3 ls ', [
's3://mybucket/', 's3://otherbucket/'])

def test_complete_path_scheme_only(self):
self.assert_completion(self.completer, 's3 ls s3://', [
'//mybucket/', '//otherbucket/'])

def test_complete_path_partial_bucket1(self):
self.assert_completion(self.completer, 's3 ls s3://m', [
'//mybucket/'])

def test_complete_path_partial_bucket2(self):
self.assert_completion(self.completer, 's3 ls s3://o', [
'//otherbucket/'])

def test_complete_path_prefix(self):
clidriver = self.completer.driver
clidriver.session.create_client.return_value.get_paginator.return_value\
.paginate.return_value = [{
'Contents': [
{'Key': 'key1'},
{'Key': 'key2'}
],
'CommonPrefixes': [
{'Prefix': 'prefix1'},
{'Prefix': 'prefix2'}
]
}]

self.assert_completion(self.completer, 's3 ls s3://mybucket/', [
'//mybucket/key1',
'//mybucket/key2',
'//mybucket/prefix1',
'//mybucket/prefix2'
])

def test_complete_local_path(self):
check_output_patch = mock.patch('subprocess.check_output')
check_output_patch_mock = check_output_patch.start()
check_output_patch_mock.return_value = b'file1\nfile2\n'

self.assert_completion(self.completer, 's3 cp ', [
'file1', 'file2', 's3://mybucket/', 's3://otherbucket/'])

check_output_patch.stop()

class MockCLIDriverFactory(object):
def create_clidriver(self, commands=None, profiles=None):
session = mock.Mock()
Expand Down Expand Up @@ -408,3 +495,7 @@ def create_argument_table(self, arguments):
else:
argument_table[arg] = BaseCLIArgument(arg)
return argument_table


if __name__ == "__main__":
unittest.main()

0 comments on commit 3ebb2ea

Please sign in to comment.