Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix - wip #10

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Create a `config.json` file with connection details to snowflake.
"encoding": "utf-8",
"sanitize_header": false,
"skip_rows": 0,
"infer_schema": false,
"infer_schema": false
}
],
"xml_fields": [],
Expand All @@ -54,7 +54,7 @@ Create a `config.json` file with connection details to snowflake.
"gnupghome": "/your/dir/.gnupg",
"passphrase": "your_gpg_passphrase"
},
"private_key_file": "Optional_Path",
"private_key_file": "Optional_Path"
}
```
If using the decryption feature, you must pass the configs shown above, including the name of the AWS SSM parameter where the decryption private key is stored. To retrieve this parameter, the runtime environment must have access to SSM through IAM environment variables (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN).
Expand All @@ -76,11 +76,11 @@ Create a `config.json` file with connection details to snowflake.
- `delimiter`: A one-character string delimiter used to separate fields. Defaults to `,`.

The following table configuration fields are optional:
- `key_properties`: Array containing the unique keys of the table. Defaults to `['_sdc_source_file', '_sdc_source_lineno']`, representing the file name and line number. Specify an emtpy array (`[]`) to load all new files without a replication key
- `key_properties`: Array containing the unique keys of the table. Defaults to `['_sdc_source_file', '_sdc_source_lineno']`, representing the file name and line number. Specify an empty array (`[]`) to load all new files without a replication key
- `encoding`: File encoding, defaults to `utf-8`
- `sanitize_header`: Boolean, specifies whether to clean up header names so that they are more likely to be accepted by a target SQL database
- `skip_rows`: Integer, specifies the number of rows to skip at the top of the file to handle non-data content like comments or other text. Default 0.
- `infer_schema`: Boolean. If set to true (the default value), the tap will attempt to detect the data type of the fields. Otherwise all fields will be treated as strings.
- `infer_schema`: Boolean. If set to true (the default value), the tap will attempt to detect the data type of the fields. Otherwise, all fields will be treated as strings.

## Discovery mode:

Expand Down Expand Up @@ -115,7 +115,7 @@ $ tap-sftp --config config.json --catalog catalog.json --state state.json
pip install tox
```

2. To run unit tests:
1. To run unit tests:
```
tox
```
Expand Down
131 changes: 84 additions & 47 deletions tap_sftp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

def handle_backoff(details):
LOGGER.warn(
"SSH Connection closed unexpectedly. Waiting {wait} seconds and retrying...".format(**details)
"SSH Connection closed unexpectedly. Waiting {wait} seconds and retrying...".format(
**details
)
)


class SFTPConnection():
class SFTPConnection:
def __init__(self, host, username, password=None, private_key_file=None, port=None):
self.host = host
self.username = username
Expand All @@ -45,31 +47,41 @@ def __init__(self, host, username, password=None, private_key_file=None, port=No
max_tries=7,
on_backoff=handle_backoff,
jitter=None,
factor=2)
factor=2,
)
def __connect(self):
LOGGER.info('Creating new connection to SFTP...')
LOGGER.info("Creating new connection to SFTP...")
self._attempt_connection()
LOGGER.info('Connection successful')
LOGGER.info("Connection successful")

def _attempt_connection(self):
paramiko.Transport._preferred_ciphers = ("ssh-rsa", "ecdsa-sha2-nistp256", )
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
try:
ssh_client = paramiko.SSHClient()
ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh_client.connect(
client.connect(
hostname=self.host,
port=self.port,
username=self.username,
password=self.password,
pkey=self.key,
compress=True,
timeout=120
)
self.sftp = ssh_client.open_sftp()
except (AuthenticationException, SSHException) as ex:
LOGGER.warning('Connection attempt failed: %s', ex)
if ssh_client:
ssh_client.close()
raise
self.sftp = client.open_sftp()
except paramiko.AuthenticationException as ex:
LOGGER.error(
"AuthenticationException, please verify your credentials: %s", ex
)
except paramiko.SSHException as ex:
LOGGER.error("SSHException, could not establish SSH connection: %s", ex)
except paramiko.BadHostKeyException as ex:
LOGGER.error("BadHostKeyException, bad host key: %s", ex)
except Exception as ex:
LOGGER.error(
"Exception, failed to connect or establish SFTP session: %s", ex
)
finally:
if client:
client.close()

def close(self):
if self.sftp is not None:
Expand All @@ -79,7 +91,11 @@ def close(self):
self.decrypted_file.close()

def match_files_for_table(self, files, table_name, search_pattern):
LOGGER.info("Searching for files for table '%s', matching pattern: %s", table_name, table_pattern)
LOGGER.info(
"Searching for files for table '%s', matching pattern: %s",
table_name,
table_pattern,
)
matcher = re.compile(search_pattern)
return [f for f in files if matcher.search(f["filepath"])]

Expand All @@ -97,8 +113,8 @@ def get_files_by_prefix(self, prefix):
"""
files = []

if prefix is None or prefix == '':
prefix = '.'
if prefix is None or prefix == "":
prefix = "."

try:
result = self.sftp.listdir_attr(prefix)
Expand All @@ -108,21 +124,29 @@ def get_files_by_prefix(self, prefix):
for file_attr in result:
# NB: This only looks at the immediate level beneath the prefix directory
if self.is_directory(file_attr):
files += self.get_files_by_prefix(prefix + '/' + file_attr.filename)
files += self.get_files_by_prefix(prefix + "/" + file_attr.filename)
else:
if self.is_empty(file_attr):
continue

last_modified = file_attr.st_mtime
if last_modified is None:
LOGGER.warning("Cannot read m_time for file %s, defaulting to current epoch time",
os.path.join(prefix, file_attr.filename))
LOGGER.warning(
"Cannot read m_time for file %s, defaulting to current epoch time",
os.path.join(prefix, file_attr.filename),
)
last_modified = datetime.utcnow().timestamp()

# NB: SFTP specifies path characters to be '/'
# https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-6
files.append({"filepath": prefix + '/' + file_attr.filename,
"last_modified": datetime.utcfromtimestamp(last_modified).replace(tzinfo=pytz.UTC)})
files.append(
{
"filepath": prefix + "/" + file_attr.filename,
"last_modified": datetime.utcfromtimestamp(
last_modified
).replace(tzinfo=pytz.UTC),
}
)

return files

Expand All @@ -136,57 +160,70 @@ def get_files(self, prefix, search_pattern, modified_since=None):
matching_files = self.get_files_matching_pattern(files, search_pattern)

if matching_files:
LOGGER.info('Found %s files in "%s" matching "%s"', len(matching_files), prefix, search_pattern)
LOGGER.info(
'Found %s files in "%s" matching "%s"',
len(matching_files),
prefix,
search_pattern,
)
else:
LOGGER.warning('Found no files on specified SFTP server at "%s" matching "%s"', prefix, search_pattern)

LOGGER.warning(
'Found no files on specified SFTP server at "%s" matching "%s"',
prefix,
search_pattern,
)

if modified_since is not None:
LOGGER.info("Processing files modified since: %s", modified_since)
matching_files = [f for f in matching_files if f["last_modified"] > modified_since]
matching_files = [
f for f in matching_files if f["last_modified"] > modified_since
]

for f in matching_files:
LOGGER.info("Found file: %s", f['filepath'])
LOGGER.info("Found file: %s", f["filepath"])

matching_files = sorted(matching_files, key=lambda x: x["last_modified"])
return matching_files

def get_file_handle(self, f, decryption_configs=None):
""" Takes a file dict {"filepath": "...", "last_modified": "..."} and returns a handle to the file. """
"""Takes a file dict {"filepath": "...", "last_modified": "..."} and returns a handle to the file."""
with tempfile.TemporaryDirectory() as tmpdirname:
sftp_file_path = f["filepath"]
local_path = f'{tmpdirname}/{os.path.basename(sftp_file_path)}'
local_path = f"{tmpdirname}/{os.path.basename(sftp_file_path)}"
if decryption_configs:
LOGGER.info(f'Decrypting file: {sftp_file_path}')
LOGGER.info(f"Decrypting file: {sftp_file_path}")
# Getting sftp file to local, then reading it is much faster than reading it directly from the SFTP
self.sftp.get(sftp_file_path, local_path)
decrypted_path = decrypt.gpg_decrypt(
local_path,
tmpdirname,
decryption_configs.get('key'),
decryption_configs.get('gnupghome'),
decryption_configs.get('passphrase')
decryption_configs.get("key"),
decryption_configs.get("gnupghome"),
decryption_configs.get("passphrase"),
)
LOGGER.info(f'Decrypting file complete')
LOGGER.info(f"Decrypting file complete")
try:
self.decrypted_file = open(decrypted_path, 'rb')
self.decrypted_file = open(decrypted_path, "rb")
except FileNotFoundError:
raise Exception(f'Decryption of file failed: {sftp_file_path}')
raise Exception(f"Decryption of file failed: {sftp_file_path}")
return self.decrypted_file, decrypted_path
else:
self.sftp.get(sftp_file_path, local_path)
return open(local_path, 'rb')
return open(local_path, "rb")

def get_files_matching_pattern(self, files, pattern):
""" Takes a file dict {"filepath": "...", "last_modified": "..."} and a regex pattern string, and returns
files matching that pattern. """
"""Takes a file dict {"filepath": "...", "last_modified": "..."} and a regex pattern string, and returns
files matching that pattern."""
matcher = re.compile(pattern)
LOGGER.info(f"Searching for files for matching pattern: {pattern}")
return [f for f in files if matcher.search(f["filepath"])]


def connection(config):
return SFTPConnection(config['host'],
config['username'],
password=config.get('password'),
private_key_file=config.get('private_key_file'),
port=config.get('port'))
return SFTPConnection(
config["host"],
config["username"],
password=config.get("password"),
private_key_file=config.get("private_key_file"),
port=config.get("port"),
)