# Module with functions related to input/output operations.
#
# Author: Fernando García Gutiérrez
# Email: ga.gu.fernando.concat@gmail.com
#
import os
import json
import joblib
import pickle
import gzip
import torch
import pandas as pd
from datetime import datetime
from pathlib import Path
from . import login as base_login
from ..util.validation import (
checkMultiInputTypes,
checkInputType,
fileExists,
pathExists
)
# TODO. Implement a custom memory-efficient backend
# available backends used for saving and loading Python objects
_AVAILABLE_SERIALIZATION_BACKENDS = ['joblib', 'pickle', 'joblib_gzip', 'pickle_gzip']
_DEFAULT_BACKEND = 'joblib_gzip'
def saveJson(data: dict, file: str):
    """ Save the input dictionary to a JSON file.

    Parameters
    ----------
    data : dict
        Dictionary to be exported to a JSON file. Note that numpy types must
        be converted to native Python types beforehand.
    file : str
        Output JSON file. An existing file will not be overwritten (an error
        is raised instead).
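
    Examples
    --------
    A minimal sketch (the file name and contents are hypothetical):

    >>> saveJson({'accuracy': 0.91, 'loss': 0.23}, 'metrics.json')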
"""
checkMultiInputTypes(
('data', data, [dict]),
('file', file, [str]))
    fileExists(file, False)   # avoid overwriting an existing file
with open(file, 'w') as f:
json.dump(data, f)
def loadJson(file: str) -> dict:
    """ Load a JSON file.

    Parameters
    ----------
    file : str
        JSON file to be loaded.

    Returns
    -------
    content : dict
        JSON file content.
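
    Examples
    --------
    A minimal sketch (the file name is hypothetical):

    >>> metrics = loadJson('metrics.json')
    >>> isinstance(metrics, dict)
    True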
"""
checkInputType('file', file, [str])
    fileExists(file, True)   # the file must already exist
with open(file) as f:
content = json.load(f)
return content
def serialize(obj, path: str, time_prefix: bool = False, overwrite: bool = False,
              backend: str = _DEFAULT_BACKEND) -> str:
    """ Serialize a Python object to disk.

    Parameters
    ----------
    obj : object
        Object to be saved.
    path : str
        File used to save the provided object.
    time_prefix : bool, default=False
        Whether to add a time prefix (YYMMDD-HHMMSS) to the exported file name.
    overwrite : bool, default=False
        Whether to overwrite the file if it already exists.
    backend : str, default='joblib_gzip'
        Backend used to serialize the object (one of 'joblib', 'pickle',
        'joblib_gzip' or 'pickle_gzip').

    Returns
    -------
    path : str
        Path to the file containing the serialized object.
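
    Examples
    --------
    A minimal sketch (the object and file name are hypothetical). With a gzip
    backend the returned path ends in '.gz':

    >>> out = serialize({'model': 'svm', 'C': 1.0}, 'model_params', backend='joblib_gzip')
    >>> out.endswith('.gz')
    True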
"""
checkMultiInputTypes(
('path', path, [str]),
('time_prefix', time_prefix, [bool]),
('overwrite', overwrite, [bool]))
# separate path and file name
path_to_obj, filename = os.path.split(path)
path_to_obj = os.path.abspath(path_to_obj)
# add time prefix
if time_prefix:
filename = '%s-%s' % (datetime.now().strftime('%y%m%d-%H%M%S'), filename)
# create the path to the output file
file_fp = os.path.join(path_to_obj, filename)
    # the input path must already exist
pathExists(path_to_obj, must_exists=True)
if not overwrite:
fileExists(file_fp, must_exists=False)
# export the object
return _serialize(obj, file_fp, backend)
def saveTorchModel(
        base_path: str,
        key: str,
        model: torch.nn.Module
) -> str:
    """ Save the weights (state dict) of a :class:`torch.nn.Module` model.

    Parameters
    ----------
    base_path : str
        Base directory where the model will be stored. If this directory does
        not exist, it will be created.
    key : str
        Key used to identify the model. The generated file name will be
        prefixed with the current date and time (YYYYMMDD_HHMMSS).
    model : torch.nn.Module
        Model whose parameters will be saved.

    Returns
    -------
    file : str
        Path to the generated file.
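
    Examples
    --------
    A minimal sketch (directory, key and model are hypothetical):

    >>> model = torch.nn.Linear(10, 1)
    >>> file = saveTorchModel('./models', 'linear_baseline', model)
    >>> _ = model.load_state_dict(torch.load(file))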
"""
# create the directory if it does not exist
if not os.path.exists(base_path):
Path(base_path).mkdir(parents=True)
output_file = os.path.join(
base_path, '%s_%s' % (
datetime.now().strftime('%Y%m%d_%H%M%S'),
key
))
with torch.no_grad():
torch.save(
model.state_dict(),
output_file
)
# clear cuda cache
torch.cuda.empty_cache()
return output_file
def saveTorchModelAndHistory(
        base_path: str,
        key: str,
        model: torch.nn.Module,
        history: dict):
    """ Subroutine used to serialize model weights and convergence history.

    Parameters
    ----------
    base_path : str
        Base directory where the model and convergence information will be stored.
        If this directory does not exist, it will be created.
    key : str
        Key used to identify the model.
    model : torch.nn.Module
        Model whose parameters will be saved.
    history : dict
        Dictionary similar to the one returned by the function
        :meth:`util.torch_util.fit_neural_network`. Its values are concatenated
        via :func:`pandas.concat` and exported to a parquet file.
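
    Examples
    --------
    A minimal sketch (all names and values are hypothetical; writing parquet
    files requires an engine such as pyarrow). The history values must be
    :class:`pandas.DataFrame` objects:

    >>> history = {
    ...     'train': pd.DataFrame({'epoch': [1, 2], 'loss': [0.80, 0.52]}),
    ...     'valid': pd.DataFrame({'epoch': [1, 2], 'loss': [0.91, 0.64]})}
    >>> saveTorchModelAndHistory('./models', 'linear_baseline', torch.nn.Linear(10, 1), history)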
"""
# save the model
model_file = saveTorchModel(
base_path=base_path,
key=key,
model=model
)
# save the convergence information
pd.concat(history).to_parquet('%s_history.parquet' % model_file)
def _serialize(obj, path: str, backend: str) -> str:
""" Subroutine used to serialize objects. """
# check input types
checkMultiInputTypes(
('backend', backend, [str]),
('path', path, [str]))
# check backends
if backend not in _AVAILABLE_SERIALIZATION_BACKENDS:
raise TypeError('Unrecognized backend "%s". Available backends are: %r' % (
backend, _AVAILABLE_SERIALIZATION_BACKENDS))
if backend == 'joblib':
out = _serializeJoblib(obj, path)
elif backend == 'joblib_gzip':
out = _gzip(_serializeJoblib(obj, path))
elif backend == 'pickle':
out = _serializePickle(obj, path)
elif backend == 'pickle_gzip':
out = _gzip(_serializePickle(obj, path))
else:
assert False, 'Unhandled case'
return out
def _serializeJoblib(obj, path: str) -> str:
""" Joblib serialization backend. """
joblib.dump(obj, path)
return path
def _serializePickle(obj, path: str) -> str:
""" Pickle serialization backend. """
with open(path, 'wb') as f:
pickle.dump(obj, f)
return path
def _gzip(in_path: str) -> str:
""" Apply a gzip compression. """
with open(in_path, 'rb') as f:
with gzip.open(in_path + '.gz', 'wb') as fgz:
fgz.writelines(f)
os.remove(in_path) # remove uncompressed file
return in_path + '.gz'
def load(file: str, backend: str = _DEFAULT_BACKEND) -> object:
    """ Load a serialized Python object (see :py:func:`gojo.util.io.serialize`).

    Parameters
    ----------
    file : str
        File containing the object to be loaded.
    backend : str, default='joblib_gzip'
        Backend used to serialize the object (one of 'joblib', 'pickle',
        'joblib_gzip' or 'pickle_gzip').

    Returns
    -------
    obj : object
        Loaded object.
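
    Examples
    --------
    A minimal sketch (the file name is hypothetical); the backend must match
    the one used when calling :py:func:`gojo.util.io.serialize`:

    >>> params = load('model_params.gz', backend='joblib_gzip')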
"""
checkMultiInputTypes(
('file', file, [str]),
('backend', backend, [str]))
# check backends
if backend not in _AVAILABLE_SERIALIZATION_BACKENDS:
raise TypeError('Unrecognized backend "%s". Available backends are: %r' % (
backend, _AVAILABLE_SERIALIZATION_BACKENDS))
# check that the input file exists
fileExists(file, must_exists=True)
# load the object
if backend == 'joblib':
obj = _loadJoblib(file)
elif backend == 'joblib_gzip':
obj = _loadJoblibGzip(file)
elif backend == 'pickle':
obj = _loadPickle(file)
elif backend == 'pickle_gzip':
obj = _loadPickleGzip(file)
else:
assert False, 'Unhandled case'
return obj
def pprint(*args, verbose: bool = True, level: str = None, sep: str = ' '):
    """ Print function for the :py:mod:`gojo` module. If the :py:mod:`gojo`
    logger is active, the message is dispatched to the logger at the given
    `level`; otherwise this function falls back to the built-in :func:`print`.

    Parameters
    ----------
    *args
        Values to be printed.
    verbose : bool, default=True
        If False, suppress all output.
    level : str, default=None
        Logger level used when the logger is active; must be one of
        `base_login.Login.logger_levels`.
    sep : str, default=' '
        Separator inserted between values.
if verbose:
if base_login.isActive():
level = level.lower() if level is not None else level
if level not in base_login.Login.logger_levels:
raise TypeError(
'Input level "{}" not found. Available levels are: {}'.format(
level, base_login.Login.logger_levels))
base_login.Login.logger_levels[level](sep.join([str(arg) for arg in args]))
else:
            print(*args, sep=sep)
def _loadJoblib(path: str) -> object:
""" Load a joblib serialized object. """
return joblib.load(path)
def _loadJoblibGzip(path: str) -> object:
""" Load a joblib + gzip serialized object. """
with gzip.open(path, 'rb') as fgz:
obj = joblib.load(fgz)
return obj
def _loadPickle(path: str) -> object:
""" Load a pickle serialized object. """
with open(path, 'rb') as f:
obj = pickle.load(f)
return obj
def _loadPickleGzip(path: str) -> object:
""" Load a joblib + gzip serialized object. """
with gzip.open(path, 'rb') as fgz:
obj = pickle.load(fgz)
return obj
def _createObjectRepresentation(class_name: str, **parameters) -> str:
    """ Create the object representation returned by __repr__() methods.
checkInputType('class_name', class_name, [str])
representation = '{}('.format(class_name)
    if not parameters:   # no parameters provided: close the representation directly
        representation += ')'
else:
for k, v in parameters.items():
representation += '\n {}={},'.format(k, v)
representation = representation.rstrip(',')
representation += '\n)'
return representation