Skip to content

Commit

Permalink
completer improvements
Browse files Browse the repository at this point in the history
add support for s3 and local path completion in s3 commands.
the s3 path completion includes both bucket names and prefixes.
in case --profile/--region are passed on the command line, they are used
when performing the s3 queries.
the creation of the 'help command' was moved away from the constructor
since it takes a considerable amount of time, and is not required for path
completion.
also, support completion of --profile p<TAB> (partial profile name)
  • Loading branch information
erankor committed Dec 30, 2020
1 parent 4c71115 commit e9b8e00
Show file tree
Hide file tree
Showing 3 changed files with 258 additions and 12 deletions.
4 changes: 2 additions & 2 deletions awscli/clidriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def main():
return rc


def create_clidriver():
session = botocore.session.Session(EnvironmentVariables)
def create_clidriver(profile=None):
session = botocore.session.Session(EnvironmentVariables, profile=profile)
_set_user_agent_for_session(session)
load_plugins(session.full_config.get('plugins', {}),
event_hooks=session.get_component('event_emitter'))
Expand Down
137 changes: 128 additions & 9 deletions awscli/completer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,53 @@
import sys
import logging
import copy
import subprocess
from botocore.exceptions import ClientError

LOG = logging.getLogger(__name__)


class Completer(object):

def __init__(self, driver=None):
S3_SCHEME = 's3://'

def __init__(self, driver=None, region=None):
if driver is not None:
self.driver = driver
else:
self.driver = awscli.clidriver.create_clidriver()
self.main_help = self.driver.create_help_command()
self.main_options = self._get_documented_completions(
self.main_help.arg_table)
self.region = region
self.main_help = None
self.s3 = None

def complete(self, cmdline, point=None):
if point is None:
point = len(cmdline)
if point is not None:
cmdline = cmdline[0:point]

args = cmdline[0:point].split()
args = cmdline.split()
current_arg = args[-1]

if (not cmdline.endswith(current_arg)
and not current_arg.startswith('-')):
# There are spaces after the arg, treat it as a new arg
current_arg = ''
args.append('')

if len(args) > 1 and args[-2] == '--profile':
return [n for n in self.driver.session.available_profiles
if n.startswith(args[-1])]

args_set = set(args)
if 's3' in args_set and not current_arg.startswith('-'):
if len(args_set & set(['cp', 'mv', 'sync'])) > 0:
return self._complete_s3_arg(current_arg, True)
elif len(args_set & set(['ls', 'presign', 'rm'])) > 0:
return self._complete_s3_arg(current_arg, False)

cmd_args = [w for w in args if not w.startswith('-')]
opts = [w for w in args if w.startswith('-')]

self.main_help = self.driver.create_help_command()
cmd_name, cmd = self._get_command(self.main_help, cmd_args)
subcmd_name, subcmd = self._get_command(cmd, cmd_args)

Expand All @@ -61,6 +84,88 @@ def _complete_command(self, command_name, command_help, current_arg, opts):
command_help.command_table, current_arg)
return []


def _complete_s3_bucket(self, current_arg):
response = self.s3.list_buckets()
result = []
for bucket in response['Buckets']:
if bucket['Name'].startswith(current_arg):
result.append(bucket['Name'] + '/')
return result

def _complete_s3_prefix(self, current_arg):
split_arg = current_arg.split('/', 1)
if len(split_arg) < 2:
return self._complete_s3_bucket(current_arg)

bucket_name, prefix = split_arg
paginator = self.s3.get_paginator('list_objects')
page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix,
Delimiter='/')

result = []
for page in page_iterator:
if 'CommonPrefixes' in page:
for item in page['CommonPrefixes']:
result.append('%s/%s' % (bucket_name, item['Prefix']))
if 'Contents' in page:
for item in page['Contents']:
result.append('%s/%s' % (bucket_name, item['Key']))

return result

def _complete_s3_prefix_loop(self, current_arg=''):
result = []
for i in range(10):
try:
cur = self._complete_s3_prefix(current_arg)
except ClientError:
break
if len(cur) == 0:
break

result = cur
if len(cur) != 1:
break
current_arg = cur[0]

return [self.S3_SCHEME + n for n in result]

def _complete_local_path(self, current_arg):
# Delegate to compgen for local file completion
cmd_args = ['bash', '-c', 'compgen -f -- %s' % current_arg]
try:
return (subprocess.check_output(cmd_args).decode('utf-8')
.splitlines())
except subprocess.CalledProcessError:
return []

def _complete_s3_arg(self, current_arg, local):
self.s3 = self.driver.session.create_client('s3',
region_name=self.region)

result = []
if current_arg.startswith(self.S3_SCHEME):
# S3 path completion
result = self._complete_s3_prefix_loop(
current_arg[len(self.S3_SCHEME):])
else:
if self.S3_SCHEME.startswith(current_arg):
# Arg is a prefix of s3 scheme - perform s3 path completion
result = self._complete_s3_prefix_loop()

if local:
# Local path completion
result += self._complete_local_path(current_arg)

if ':' in current_arg:
# Bash starts the completion from the last :, strip anything that
# precedes it
strip = current_arg.rfind(':') + 1
result = [n[strip:] for n in result]

return result

def _complete_subcommand(self, subcmd_name, subcmd_help, current_arg, opts):
if current_arg != subcmd_name and current_arg.startswith('-'):
return self._find_possible_options(current_arg, opts, subcmd_help)
Expand Down Expand Up @@ -109,7 +214,9 @@ def _get_documented_completions(self, table, startswith=None):
return names

def _find_possible_options(self, current_arg, opts, subcmd_help=None):
all_options = copy.copy(self.main_options)
main_options = self._get_documented_completions(
self.main_help.arg_table)
all_options = copy.copy(main_options)
if subcmd_help is not None:
all_options += self._get_documented_completions(
subcmd_help.arg_table)
Expand All @@ -130,7 +237,19 @@ def _find_possible_options(self, current_arg, opts, subcmd_help=None):


def complete(cmdline, point):
choices = Completer().complete(cmdline, point)
# Get the profile and region args
args = cmdline[0:point].split()

profile = None
if '--profile' in args[:-2]:
profile = args[args.index('--profile') + 1]

region = None
if '--region' in args[:-2]:
region = args[args.index('--region') + 1]

driver = awscli.clidriver.create_clidriver(profile=profile)
choices = Completer(driver, region).complete(cmdline, point)
print(' \n'.join(choices))


Expand Down
129 changes: 128 additions & 1 deletion tests/unit/test_completer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@
# language governing permissions and limitations under the License.
import pprint
import difflib
import subprocess

import mock

from botocore.compat import OrderedDict
from botocore.model import OperationModel
from botocore.exceptions import ClientError
from awscli.clidriver import (
CLIDriver, ServiceCommand, ServiceOperation, CLICommand)
from awscli.arguments import BaseCLIArgument, CustomArgument
from awscli.help import ProviderHelpCommand
from awscli.completer import Completer
from awscli.completer import Completer, complete
from awscli.testutils import unittest
from awscli.customizations.commands import BasicCommand

Expand Down Expand Up @@ -345,6 +347,127 @@ def test_complete_custom_command_arguments_with_arg_already_used(self):
'--bar', '--sse'])


class TestProfileRegionParsed(BaseCompleterTest):
def test_profile_region_used(self):
create_clidriver_patch = mock.patch('awscli.clidriver.create_clidriver')
create_clidriver_patch_mock = create_clidriver_patch.start()

cmdline = 'aws s3 --profile myprofile --region dummyregion ls '
complete(cmdline, len(cmdline))

create_clidriver_patch.stop()

create_clidriver_patch_mock.assert_called_with(profile='myprofile')
create_clidriver_patch_mock.return_value.session.create_client\
.assert_called_with('s3', region_name='dummyregion')


class TestProfileCompletion(BaseCompleterTest):
def setUp(self):
super(TestProfileCompletion, self).setUp()
commands = {
'subcommands': {},
'arguments': ['profile']
}
self.completer = Completer(
self.clidriver_creator.create_clidriver(commands,
profiles=['testprofile1', 'testprofile2', 'other']))

def test_complete_profile_empty(self):
self.assert_completion(self.completer, 's3 ls s3:// --profile ', [
'testprofile1', 'testprofile2', 'other'])

def test_complete_profile_partial_multi(self):
self.assert_completion(self.completer, 's3 ls s3:// --profile test', [
'testprofile1', 'testprofile2'])

def test_complete_profile_partial_single(self):
self.assert_completion(self.completer, 's3 ls s3:// --profile o', [
'other'])


class TestS3PathCompletion(BaseCompleterTest):
def setUp(self):
super(TestS3PathCompletion, self).setUp()
commands = {
'subcommands': {},
'arguments': []
}
clidriver = self.clidriver_creator.create_clidriver(commands)

clidriver.session.create_client.return_value.list_buckets.return_value\
= {'Buckets': [{'Name': 'mybucket'}, {'Name': 'otherbucket'}]}
clidriver.session.create_client.return_value.get_paginator.return_value\
.paginate.return_value = [{'Contents': [], 'CommonPrefixes': []}]

self.completer = Completer(clidriver)

def test_complete_path_empty(self):
self.assert_completion(self.completer, 's3 ls ', [
's3://mybucket/', 's3://otherbucket/'])

def test_complete_path_scheme_only(self):
self.assert_completion(self.completer, 's3 ls s3://', [
'//mybucket/', '//otherbucket/'])

def test_complete_path_partial_bucket1(self):
self.assert_completion(self.completer, 's3 ls s3://m', [
'//mybucket/'])

def test_complete_path_partial_bucket2(self):
self.assert_completion(self.completer, 's3 ls s3://o', [
'//otherbucket/'])

def test_complete_path_prefix(self):
clidriver = self.completer.driver
clidriver.session.create_client.return_value.get_paginator.return_value\
.paginate.return_value = [{
'Contents': [
{'Key': 'key1'},
{'Key': 'key2'}
],
'CommonPrefixes': [
{'Prefix': 'prefix1'},
{'Prefix': 'prefix2'}
]
}]

self.assert_completion(self.completer, 's3 ls s3://mybucket/', [
'//mybucket/key1',
'//mybucket/key2',
'//mybucket/prefix1',
'//mybucket/prefix2'
])

def test_complete_path_error(self):
clidriver = self.completer.driver
clidriver.session.create_client.return_value.list_buckets.side_effect\
= ClientError({
'Error': {'AccessDenied': 'NoSuchKey', 'Message': 'foo'}},
'ListBuckets')
self.assert_completion(self.completer, 's3 ls ', [])

def test_complete_local_path(self):
check_output_patch = mock.patch('subprocess.check_output')
check_output_patch_mock = check_output_patch.start()
check_output_patch_mock.return_value = b'file1\nfile2\n'

self.assert_completion(self.completer, 's3 cp ', [
'file1', 'file2', 's3://mybucket/', 's3://otherbucket/'])

check_output_patch.stop()

def test_complete_local_path_error(self):
check_output_patch = mock.patch('subprocess.check_output')
check_output_patch_mock = check_output_patch.start()
check_output_patch_mock.side_effect = subprocess.CalledProcessError(
127, 'compgen')

self.assert_completion(self.completer, 's3 cp ', [
's3://mybucket/', 's3://otherbucket/'])

check_output_patch.stop()

class MockCLIDriverFactory(object):
def create_clidriver(self, commands=None, profiles=None):
session = mock.Mock()
Expand Down Expand Up @@ -408,3 +531,7 @@ def create_argument_table(self, arguments):
else:
argument_table[arg] = BaseCLIArgument(arg)
return argument_table


if __name__ == "__main__":
unittest.main()

0 comments on commit e9b8e00

Please sign in to comment.