Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ export DATABRICKS_TOKEN=<TOKEN>
Windows PowerShell

``` cmd
$env DATABRICKS_HOST="HOST"
$env DATABRICKS_TOKEN="TOKEN"
$env:DATABRICKS_HOST="HOST"
$env:DATABRICKS_TOKEN="TOKEN"
```

__Note:__ For more information about personal access tokens review [Databricks API Authentication](https://docs.azuredatabricks.net/dev-tools/api/latest/authentication.html).
Expand Down Expand Up @@ -315,6 +315,7 @@ FLAGS
--tags_report Create a CSV report from the test results that includes the test cases tags.
--max_parallel_tests Sets the level of parallelism for test notebook execution.
--recursive Executes all tests in the hierarchical folder structure.
--poll_wait_time Interval, in seconds, between notebook status polls. Default is 5.
```

__Note:__ You can also use flags syntax for POSITIONAL ARGUMENTS
Expand Down
13 changes: 7 additions & 6 deletions cli/nuttercli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
import datetime

import common.api as api
from common.apiclient import InvalidConfigurationException
from common.apiclient import DEFAULT_POLL_WAIT_TIME, InvalidConfigurationException

import common.resultsview as view
from .eventhandlers import ConsoleEventHandler
from .resultsvalidator import ExecutionResultsValidator
from .reportsman import ReportWriters
from . import reportsman as reports

__version__ = '0.1.33'
__version__ = '0.1.34'

BUILD_NUMBER_ENV_VAR = 'NUTTER_BUILD_NUMBER'

Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(self, debug=False, log_to_file=False, version=False):
def run(self, test_pattern, cluster_id,
timeout=120, junit_report=False,
tags_report=False, max_parallel_tests=1,
recursive=False):
recursive=False, poll_wait_time=DEFAULT_POLL_WAIT_TIME):
try:
logging.debug(""" Running tests. test_pattern: {} cluster_id: {} timeout: {}
junit_report: {} max_parallel_tests: {}
Expand All @@ -67,14 +67,15 @@ def run(self, test_pattern, cluster_id,
if self._is_a_test_pattern(test_pattern):
logging.debug('Executing pattern')
results = self._nutter.run_tests(
test_pattern, cluster_id, timeout, max_parallel_tests, recursive)
test_pattern, cluster_id, timeout,
max_parallel_tests, recursive, poll_wait_time)
self._nutter.events_processor_wait()
self._handle_results(results, junit_report, tags_report)
return

logging.debug('Executing single test')
result = self._nutter.run_test(test_pattern, cluster_id,
timeout)
timeout, poll_wait_time)

self._handle_results([result], junit_report, tags_report)

Expand Down Expand Up @@ -141,7 +142,7 @@ def _is_a_test_pattern(self, pattern):
segments = pattern.split('/')
if len(segments) > 0:
search_pattern = segments[len(segments)-1]
if search_pattern.lower().startswith('test_'):
if api.TestNotebook._is_valid_test_name(search_pattern):
return False
return True
logging.Fatal(
Expand Down
34 changes: 21 additions & 13 deletions common/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
"""

from abc import abstractmethod, ABCMeta
from common.apiclient import DEFAULT_POLL_WAIT_TIME
from . import utils
from .testresult import TestResults
from . import scheduler
from . import apiclient
Expand Down Expand Up @@ -85,19 +87,22 @@ def list_tests(self, path, recursive=False):

return tests

def run_test(self, testpath, cluster_id, timeout=120):
def run_test(self, testpath, cluster_id,
timeout=120, pull_wait_time=DEFAULT_POLL_WAIT_TIME):
self._add_status_event(NutterStatusEvents.TestExecutionRequest, testpath)
test_notebook = TestNotebook.from_path(testpath)
if test_notebook is None:
raise InvalidTestException

result = self.dbclient.execute_notebook(
test_notebook.path, cluster_id, timeout=timeout)
test_notebook.path, cluster_id,
timeout=timeout, pull_wait_time=pull_wait_time)

return result

def run_tests(self, pattern, cluster_id,
timeout=120, max_parallel_tests=1, recursive=False):
timeout=120, max_parallel_tests=1, recursive=False,
poll_wait_time=DEFAULT_POLL_WAIT_TIME):

self._add_status_event(NutterStatusEvents.TestExecutionRequest, pattern)
root, pattern_to_match = self._get_root_and_pattern(pattern)
Expand All @@ -114,7 +119,7 @@ def run_tests(self, pattern, cluster_id,
NutterStatusEvents.TestsListingFiltered, len(filtered_notebooks))

return self._schedule_and_run(
filtered_notebooks, cluster_id, max_parallel_tests, timeout)
filtered_notebooks, cluster_id, max_parallel_tests, timeout, poll_wait_time)

def events_processor_wait(self):
if self._events_processor is None:
Expand Down Expand Up @@ -163,20 +168,20 @@ def _get_root_and_pattern(self, pattern):
return root, valid_pattern

def _schedule_and_run(self, test_notebooks, cluster_id,
max_parallel_tests, timeout):
max_parallel_tests, timeout, pull_wait_time):
func_scheduler = scheduler.get_scheduler(max_parallel_tests)
for test_notebook in test_notebooks:
self._add_status_event(
NutterStatusEvents.TestScheduling, test_notebook.path)
logging.debug(
'Scheduling execution of: {}'.format(test_notebook.path))
func_scheduler.add_function(self._execute_notebook,
test_notebook.path, cluster_id, timeout)
test_notebook.path, cluster_id, timeout, pull_wait_time)
return self._run_and_await(func_scheduler)

def _execute_notebook(self, test_notebook_path, cluster_id, timeout):
def _execute_notebook(self, test_notebook_path, cluster_id, timeout, pull_wait_time):
result = self.dbclient.execute_notebook(test_notebook_path,
cluster_id, None, timeout)
cluster_id, None, timeout, pull_wait_time)
self._add_status_event(NutterStatusEvents.TestExecuted,
ExecutionResultEventData.from_execution_results(result))
logging.debug('Executed: {}'.format(test_notebook_path))
Expand Down Expand Up @@ -212,12 +217,18 @@ def __init__(self, name, path):

self.name = name
self.path = path
self.test_name = name.split("_")[1]
self.test_name = self.get_test_name(name)

def __eq__(self, obj):
is_equal = obj.name == self.name and obj.path == self.path
return isinstance(obj, TestNotebook) and is_equal

def get_test_name(self, name):
    """Derive the bare test name from a notebook name.

    A notebook named ``test_<x>...`` yields the segment right after the
    prefix; one named ``...<x>_test`` yields the segment before the
    suffix. Matching is case-insensitive. Returns None when the name
    carries neither marker.
    """
    lowered = name.lower()
    if lowered.startswith('test_'):
        return name.split('_')[1]
    if lowered.endswith('_test'):
        return name.split('_')[0]
    # Neither marker present: fall through with no test name.
    return None

@classmethod
def from_path(cls, path):
name = cls._get_notebook_name_from_path(path)
Expand All @@ -227,10 +238,7 @@ def from_path(cls, path):

@classmethod
def _is_valid_test_name(cls, name):
if name is None:
return False

return name.lower().startswith('test_')
return utils.contains_test_prefix_or_surfix(name)

@classmethod
def _get_notebook_name_from_path(cls, path):
Expand Down
18 changes: 13 additions & 5 deletions common/apiclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from .httpretrier import HTTPRetrier
import logging

# Seconds to sleep between successive job-status polls while waiting for a
# notebook run to reach a terminal state.
DEFAULT_POLL_WAIT_TIME = 5
# Assigned to DatabricksAPIClient.min_timeout; presumably the minimum
# accepted execution timeout in seconds — TODO confirm at the call sites.
MIN_TIMEOUT = 10

def databricks_client():

Expand All @@ -25,7 +27,7 @@ class DatabricksAPIClient(object):

def __init__(self):
config = cfg.get_auth_config()
self.min_timeout = 10
self.min_timeout = MIN_TIMEOUT

if config is None:
raise InvalidConfigurationException
Expand Down Expand Up @@ -55,7 +57,8 @@ def list_objects(self, path):
return workspace_path_obj

def execute_notebook(self, notebook_path, cluster_id,
notebook_params=None, timeout=120):
notebook_params=None, timeout=120,
pull_wait_time=DEFAULT_POLL_WAIT_TIME):
if not notebook_path:
raise ValueError("empty path")
if not cluster_id:
Expand All @@ -66,6 +69,8 @@ def execute_notebook(self, notebook_path, cluster_id,
if notebook_params is not None:
if not isinstance(notebook_params, dict):
raise ValueError("Parameters must be a dictionary")
if pull_wait_time <= 1:
pull_wait_time = DEFAULT_POLL_WAIT_TIME

name = str(uuid.uuid1())
ntask = self.__get_notebook_task(notebook_path, notebook_params)
Expand All @@ -80,11 +85,11 @@ def execute_notebook(self, notebook_path, cluster_id,
raise NotebookTaskRunIDMissingException

life_cycle_state, output = self.__pull_for_output(
runid['run_id'], timeout)
runid['run_id'], timeout, pull_wait_time)

return ExecuteNotebookResult.from_job_output(output)

def __pull_for_output(self, run_id, timeout):
def __pull_for_output(self, run_id, timeout, pull_wait_time):
timedout = time.time() + timeout
output = {}
while time.time() < timedout:
Expand All @@ -99,8 +104,11 @@ def __pull_for_output(self, run_id, timeout):
# https://docs.azuredatabricks.net/api/latest/jobs.html#jobsrunlifecyclestate
# All these are terminal states
if lcs == 'TERMINATED' or lcs == 'SKIPPED' or lcs == 'INTERNAL_ERROR':
logging.debug('Terminal state returned. {}'.format(lcs))
return lcs, output
time.sleep(1)
logging.debug('Not terminal state returned. Sleeping {}s'.format(pull_wait_time))
time.sleep(pull_wait_time)

self._raise_timeout(output)

def _raise_timeout(self, output):
Expand Down
8 changes: 1 addition & 7 deletions common/apiclientresults.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,7 @@ def _get_notebook_name_from_path(self, path):

@property
def is_test_notebook(self):
return self._is_valid_test_name(self.name)

def _is_valid_test_name(self, name):
if name is None:
return False

return name.lower().startswith('test_')
return utils.contains_test_prefix_or_surfix(self.name)


class Directory(WorkspaceObject):
Expand Down
8 changes: 7 additions & 1 deletion common/httpretrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,16 @@ def execute(self, function, *args, **kwargs):
raise
if isinstance(exc.response.status_code, int):
if exc.response.status_code < 500:
raise
if not self._is_invalid_state_response(exc.response):
raise
if retry:
logging.debug(
'Retrying in {0}s, {1} of {2} retries'
.format(str(waitfor), str(self._tries+1), str(self._max_retries)))
sleep(waitfor)
self._tries = self._tries + 1

def _is_invalid_state_response(self, response):
if response.status_code == 400:
return 'INVALID_STATE' in response.text
return False
6 changes: 6 additions & 0 deletions common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,9 @@ def recursive_find(dict_instance, keys):
if len(keys) == 1:
return value
return recursive_find(value, keys[1:len(keys)])

def contains_test_prefix_or_surfix(name):
    """Return True if *name* marks a test notebook.

    A name qualifies when, case-insensitively, it starts with ``test_``
    or ends with ``_test``. ``None`` is never a test name.

    NOTE: 'surfix' is a typo for 'suffix'; the name is kept as-is because
    other modules call this function by this spelling.
    """
    if name is None:
        return False
    # Lowercase once instead of once per check.
    lowered = name.lower()
    return lowered.startswith('test_') or lowered.endswith('_test')
13 changes: 12 additions & 1 deletion runtime/nutterfixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ def execute_tests(self):
return TestExecResults(self.test_results)

def __load_fixture(self):
if hasattr(self, 'data_loader') is False:
msg = """ If you have an __init__ method in your test class.
Make sure you make a call to initialize the parent class.
For example: super().__init__() """
raise InitializationException(msg)

test_case_dict = self.data_loader.load_fixture(self)
if test_case_dict is None:
logging.fatal("Invalid Test Fixture")
Expand All @@ -72,4 +78,9 @@ def __has_method(self, method_name):


class InvalidTestFixtureException(Exception):
pass
def __init__(self, message):
super().__init__(message)

class InitializationException(Exception):
def __init__(self, message):
super().__init__(message)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ def parse_requirements(filename):
"Operating System :: OS Independent",
],
python_requires='>=3.5.2',
)
)
18 changes: 18 additions & 0 deletions tests/databricks/test_httpretrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from common.httpretrier import HTTPRetrier
import requests
import io
from requests.exceptions import HTTPError
from databricks_api import DatabricksAPI

Expand Down Expand Up @@ -65,6 +66,23 @@ def test__execute__raises_500_http_exception__retries_twice_and_raises(mocker):
return_value = retrier.execute(db.jobs.get_run_output, 1)
assert retrier._tries == 2

def test__execute__raises_invalid_state_http_exception__retries_twice_and_raises(mocker):
    # A 400 response whose body contains the INVALID_STATE error code must be
    # treated as retryable: the retrier should exhaust its retries and then
    # re-raise the HTTPError.
    retrier = HTTPRetrier(2,1)  # max_retries=2, wait=1s between attempts

    db = DatabricksAPI(host='HOST',token='TOKEN')
    # Patch the underlying requests session so every call returns our
    # hand-built 400 response instead of hitting the network.
    mock_request = mocker.patch.object(db.client.session, 'request')
    response_body = " { 'error_code': 'INVALID_STATE', 'message': 'Run result is empty. " + \
        " There may have been issues while saving or reading results.'} "

    # Build a real requests.Response so raise_for_status()/text behave
    # exactly as they would for a live 400 reply.
    mock_resp = requests.models.Response()
    mock_resp.status_code = 400
    mock_resp.raw = io.BytesIO(bytes(response_body, 'utf-8'))
    mock_request.return_value = mock_resp

    with pytest.raises(HTTPError):
        return_value = retrier.execute(db.jobs.get_run_output, 1)
    # Both retries were consumed before the exception propagated.
    assert retrier._tries == 2

def test__execute__raises_403_http_exception__no_retries_and_raises(mocker):
retrier = HTTPRetrier(2,1)

Expand Down
Loading