keras-source-1

from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Each split is a (images, labels) tuple of numpy arrays, fetched and cached
# by keras.datasets.mnist.load_data.

The source code of `keras.datasets.mnist`:

from __future__ import absolute_import # use std lib
from __future__ import division # use accurate division
from __future__ import print_function # use py3 print
from ..utils.data_utils import get_file # fetch file
import numpy as np
def load_data(path='mnist.npz'):
    """Loads the MNIST dataset, downloading it on first use.

    # Arguments
        path: name under which to cache the dataset locally
            (relative to ~/.keras/datasets).

    # Returns
        Tuple of numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    # Resolve (and, if needed, download) the .npz file via get_file;
    # the hash lets get_file detect a corrupt or partial cached copy.
    path = get_file(path,
                    origin='https://s3.amazonaws.com/img-datasets/mnist.npz',
                    file_hash='8a61469f7ea1b51cbae51d4f78837e45')
    # np.load returns an NpzFile; use it as a context manager so the
    # underlying zip handle is closed even if an array lookup raises
    # (the original f = np.load(...); ...; f.close() leaked on error).
    with np.load(path) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
    return (x_train, y_train), (x_test, y_test)

The source code of `keras.utils.data_utils.get_file`:

def get_file(fname,
             origin,
             untar=False,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
    """Downloads a file from a URL if it is not already in the cache.

    # Arguments
        fname: Name of the file (the cached copy is saved under this name).
        origin: Original URL of the file.
        untar: Deprecated in favor of `extract`. Whether the file should be
            decompressed as a tar archive.
        md5_hash: Deprecated in favor of `file_hash`. md5 hash of the file.
        file_hash: Expected hash of the file after download.
        cache_subdir: Subdirectory under the cache dir where the file lives.
        hash_algorithm: Hash algorithm used to verify the file
            ('md5', 'sha256', or 'auto').
        extract: If True, try extracting the file as an archive.
        archive_format: Archive format to try when extracting.
        cache_dir: Location to store cached files; defaults to ~/.keras.

    # Returns
        Path to the cached (possibly freshly downloaded) file.
    """
    if cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
    if md5_hash is not None and file_hash is None:
        # Legacy md5_hash argument: treat it as an md5-algorithm file hash.
        file_hash = md5_hash
        hash_algorithm = 'md5'
    datadir_base = os.path.expanduser(cache_dir)
    if not os.access(datadir_base, os.W_OK):
        # Cache dir not writable (e.g. read-only home): fall back to /tmp.
        datadir_base = os.path.join('/tmp', '.keras')
    datadir = os.path.join(datadir_base, cache_subdir)
    if not os.path.exists(datadir):
        os.makedirs(datadir)
    if untar:
        untar_fpath = os.path.join(datadir, fname)  # extracted-file path
        fpath = untar_fpath + '.tar.gz'             # downloaded-archive path
    else:
        fpath = os.path.join(datadir, fname)
    download = False
    if os.path.exists(fpath):
        # File found; verify integrity if a hash was provided.
        if file_hash is not None:
            if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
                print('A local file was found, but it seems to be '
                      'incomplete or outdated because the ' + hash_algorithm +
                      ' file hash does not match the original value of ' +
                      file_hash + ' so we will re-download the data.')
                download = True
    else:
        download = True
    if download:
        print('Downloading data from', origin)

        class ProgressTracker(object):
            # Maintain progbar for the lifetime of download.
            # This design was chosen for Python 2.7 compatibility.
            progbar = None

        def dl_progress(count, block_size, total_size):
            # urlretrieve reporthook: `count` blocks of `block_size` bytes
            # transferred so far; `total_size` is the server-reported total
            # (-1 when unknown).
            if ProgressTracker.progbar is None:
                # BUG FIX: compare with `==`, not `is` -- identity checks on
                # integers are implementation-dependent and unreliable.
                if total_size == -1:
                    total_size = None  # Progbar renders an "Unknown" target
                ProgressTracker.progbar = Progbar(total_size)
            else:
                ProgressTracker.progbar.update(count * block_size)
        error_msg = 'URL fetch failure on {}: {} -- {}'
        try:
            try:
                urlretrieve(origin, fpath, dl_progress)
            except HTTPError as e:
                raise Exception(error_msg.format(origin, e.code, e.msg))
            except URLError as e:
                raise Exception(error_msg.format(origin, e.errno, e.reason))
        except (Exception, KeyboardInterrupt):
            # Never leave a partial download behind in the cache.
            if os.path.exists(fpath):
                os.remove(fpath)
            raise
        ProgressTracker.progbar = None  # reset for the next download
    if untar:
        # Precaution, in case the extracted path was removed by the user.
        if not os.path.exists(untar_fpath):
            _extract_archive(fpath, datadir, archive_format='tar')
        return untar_fpath
    if extract:
        _extract_archive(fpath, datadir, archive_format)
    return fpath

The source code of `keras.utils.generic_utils.Progbar`:

class Progbar(object):
    """Displays a text progress bar on stdout.

    # Arguments
        target: Total number of steps expected, None if unknown.
        width: Progress bar width on screen, in characters.
        verbose: Verbosity mode: 0 (silent), 1 (verbose), 2 (semi-verbose).
        interval: Minimum visual progress update interval (in seconds).
        stateful_metrics: Iterable of names of metrics that are displayed
            as-is instead of being averaged over time.
    """
    def __init__(self, target, width=30, verbose=1, interval=0.05,
                 stateful_metrics=None):
        self.target = target  # total steps expected; None if unknown
        self.width = width  # bar width in characters
        self.verbose = verbose  # 0: silent, 1: verbose, 2: semi-verbose
        self.interval = interval  # minimum seconds between redraws
        if stateful_metrics:
            self.stateful_metrics = set(stateful_metrics)  # never averaged
        else:
            self.stateful_metrics = set()

        # In-place redraw only works on a real terminal or inside Jupyter;
        # otherwise each update is written on a fresh line.
        self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
                                  sys.stdout.isatty()) or
                                 'ipykernel' in sys.modules)
        self._total_width = 0  # characters emitted by the previous redraw
        self._seen_so_far = 0  # last `current` value passed to update()
        self._values = collections.OrderedDict()  # name -> [weighted sum, count]
        self._start = time.time()  # wall-clock time when the bar was created
        self._last_update = 0  # timestamp of the previous redraw
    def update(self, current, values=None):
        """Updates the progress bar.

        # Arguments
            current: Index of the current step.
            values: List of `(name, value_for_last_step)` tuples. Regular
                metrics are averaged over the steps seen so far; metrics in
                `stateful_metrics` are displayed as-is.
        """
        values = values or []
        for k, v in values:
            if k not in self.stateful_metrics:
                # Running weighted average: weight each value by the number
                # of steps elapsed since the previous update.
                if k not in self._values:
                    self._values[k] = [v * (current - self._seen_so_far),
                                       current - self._seen_so_far]
                else:
                    self._values[k][0] += v * (current - self._seen_so_far)
                    self._values[k][1] += (current - self._seen_so_far)
            else:
                # Stateful metrics output a numeric value.  This representation
                # means "take an average from a single value" but keeps the
                # numeric formatting.
                self._values[k] = [v, 1]
        self._seen_so_far = current
        now = time.time()
        info = ' - %.0fs' % (now - self._start)  # elapsed seconds since start
        if self.verbose == 1:
            # Throttle redraws: skip unless the interval elapsed or we are
            # on the final step (target reached or unknown).
            if (now - self._last_update < self.interval and
                    self.target is not None and current < self.target):
                return
            prev_total_width = self._total_width  # width of the previous draw
            if self._dynamic_display:
                sys.stdout.write('\b' * prev_total_width)
                sys.stdout.write('\r')
            else:
                sys.stdout.write('\n')
            if self.target is not None:
                # Digits needed to print the target, e.g. 4 for 1000.
                numdigits = int(np.floor(np.log10(self.target))) + 1
                barstr = '%%%dd/%d [' % (numdigits, self.target)
                bar = barstr % current  # e.g. ' 1000/10000 ['
                prog = float(current) / self.target  # fraction completed
                prog_width = int(self.width * prog)  # filled portion width
                if prog_width > 0:
                    bar += ('=' * (prog_width - 1))
                    if current < self.target:
                        bar += '>'  # arrow head while still in progress
                    else:
                        bar += '='  # solid bar once finished
                bar += ('.' * (self.width - prog_width))  # pad the remainder
                bar += ']'
            else:  # target size unknown
                bar = '%7d/Unknown' % current
            self._total_width = len(bar)
            sys.stdout.write(bar)
            if current:
                time_per_unit = (now - self._start) / current  # avg secs/step
            else:
                time_per_unit = 0  # nothing processed yet
            if self.target is not None and current < self.target:
                eta = time_per_unit * (self.target - current)  # secs remaining
                # Render the ETA as h:mm:ss, m:ss, or plain seconds.
                if eta > 3600:
                    eta_format = ('%d:%02d:%02d' %
                                  (eta // 3600, (eta % 3600) // 60, eta % 60))
                elif eta > 60:
                    eta_format = '%d:%02d' % (eta // 60, eta % 60)
                else:
                    eta_format = '%ds' % eta
                info = ' - ETA: %s' % eta_format
            else:
                # Finished (or target unknown): report per-step time instead,
                # picking the unit that keeps the number readable.
                if time_per_unit >= 1:
                    info += ' %.0fs/step' % time_per_unit
                elif time_per_unit >= 1e-3:
                    info += ' %.0fms/step' % (time_per_unit * 1e3)
                else:
                    info += ' %.0fus/step' % (time_per_unit * 1e6)
            for k in self._values:
                info += ' - %s:' % k
                if isinstance(self._values[k], list):
                    # Averaged metric: weighted sum divided by step count.
                    avg = np.mean(
                        self._values[k][0] / max(1, self._values[k][1]))
                    if abs(avg) > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg  # scientific form for tiny values
                else:
                    info += ' %s' % self._values[k]
            self._total_width += len(info)  # total characters of this draw
            if prev_total_width > self._total_width:
                # Pad with spaces so leftovers of a longer previous line
                # are fully overwritten.
                info += (' ' * (prev_total_width - self._total_width))
            if self.target is not None and current >= self.target:
                info += '\n'  # move to the next line once finished
            sys.stdout.write(info)
            sys.stdout.flush()
        elif self.verbose == 2:
            # Semi-verbose mode: one summary line at the end of the process.
            if self.target is None or current >= self.target:
                for k in self._values:
                    info += ' - %s:' % k
                    avg = np.mean(
                        self._values[k][0] / max(1, self._values[k][1]))
                    if avg > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg
                info += '\n'
                sys.stdout.write(info)
                sys.stdout.flush()
        self._last_update = now
    def add(self, n, values=None):
        """Advances the bar by `n` steps (delegates to `update`)."""
        self.update(self._seen_so_far + n, values)

Summary:

  1. load data with keras.utils.data_utils.get_file
  2. print progress state with keras.utils.generic_utils.Progbar, and update state with keras.utils.generic_utils.Progbar.update
  3. operations for system directory manipulation and prints to screen, good for data I/O and visualization

你可能感兴趣的:(keras,deep_learning,source_code,source_code,keras,deep_learning)