diff --git a/Document API/tableaudocumentapi/datasource.py b/Document API/tableaudocumentapi/datasource.py index 28f735f..3a03e1e 100644 --- a/Document API/tableaudocumentapi/datasource.py +++ b/Document API/tableaudocumentapi/datasource.py @@ -18,12 +18,14 @@ class Datasource(object): # Public API. # ########################################################################### - def __init__(self, dsxml): + def __init__(self, dsxml, filename=None): """ Constructor. Default is to create datasource from xml. """ + self._filename = filename self._datasourceXML = dsxml + self._datasourceTree = ET.ElementTree(self._datasourceXML) self._name = self._datasourceXML.get('name') or self._datasourceXML.get('formatted-name') # TDS files don't have a name attribute self._version = self._datasourceXML.get('version') self._connection = Connection(self._datasourceXML.find('connection')) @@ -32,7 +34,36 @@ def __init__(self, dsxml): def from_file(cls, filename): "Initialize datasource from file (.tds)" dsxml = ET.parse(filename).getroot() - return cls(dsxml) + return cls(dsxml, filename) + + def save(self): + """ + Call finalization code and save file. + + Args: + None. + + Returns: + Nothing. + + """ + + # save the file + self._datasourceTree.write(self._filename) + + def save_as(self, new_filename): + """ + Save our file with the name provided. + + Args: + new_filename: New name for the workbook file. String. + + Returns: + Nothing. + + """ + self._datasourceTree.write(new_filename) + ########### # name diff --git a/Document API/tableaudocumentapi/workbook.py b/Document API/tableaudocumentapi/workbook.py index e2e0c75..766871b 100644 --- a/Document API/tableaudocumentapi/workbook.py +++ b/Document API/tableaudocumentapi/workbook.py @@ -90,13 +90,8 @@ def save_as(self, new_filename): """ - # We have a valid type of input file - if self._is_valid_file(new_filename): - # save the file - self._workbookTree.write(new_filename) - else: - print('Invalid file type. Must be .twb or .tds.') - raise Exception() + self._workbookTree.write(new_filename) + ########################################################################### # diff --git a/Document API/test.py b/Document API/test.py index baf95c3..7766c3b 100644 --- a/Document API/test.py +++ b/Document API/test.py @@ -78,6 +78,14 @@ def test_can_extract_connection(self): ds = Datasource.from_file(self.tds_file.name) self.assertIsInstance(ds.connection, Connection) + def test_can_save_tds(self): + original_tds = Datasource.from_file(self.tds_file.name) + original_tds.connection.dbname = 'newdb.test.tsi.lan' + original_tds.save() + + new_tds = Datasource.from_file(self.tds_file.name) + self.assertEqual(new_tds.connection.dbname, 'newdb.test.tsi.lan') + class WorkbookModelTests(unittest.TestCase):