# Load the MNIST dataset through Keras; returns two (images, labels)
# tuples of numpy arrays: one for training and one for testing.
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# load data with keras.datasets.mnist.load_data
Below is the source code of `keras.datasets.mnist`:
from __future__ import absolute_import # use std lib
from __future__ import division # use accurate division
from __future__ import print_function # use py3 print
from ..utils.data_utils import get_file # fetch file
import numpy as np
def load_data(path='mnist.npz'):
    """Load the MNIST dataset.

    # Arguments
        path: filename under the local cache directory
            (`~/.keras/datasets`) where the dataset is stored.

    # Returns
        Tuple of numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    # Fetch the .npz archive (or reuse a cached copy); get_file verifies
    # the file against the expected md5 hash and re-downloads on mismatch.
    path = get_file(path,
                    origin='https://s3.amazonaws.com/img-datasets/mnist.npz',
                    file_hash='8a61469f7ea1b51cbae51d4f78837e45')
    # NpzFile supports the context-manager protocol: the archive is
    # closed even if reading one of the arrays raises, unlike the
    # original explicit f.close() which leaked on error.
    with np.load(path) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
    return (x_train, y_train), (x_test, y_test)
Below is the source code of `keras.utils.data_utils.get_file`:
def get_file(fname,
             origin,
             untar=False,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
    """Download a file from `origin` if it is not already cached.

    # Arguments
        fname: name of the file under the cache directory.
        origin: URL of the file to download.
        untar: deprecated; if True, treat the download as a .tar.gz
            archive and return the path of the extracted file.
        md5_hash: deprecated alias for `file_hash` (md5 algorithm).
        file_hash: expected hash of the file, used to validate a
            cached copy.
        cache_subdir: subdirectory of the cache dir to store the file in.
        hash_algorithm: hash algorithm to validate with ('md5', 'sha256'
            or 'auto').
        extract: if True, extract the file as an archive after download.
        archive_format: archive format passed to `_extract_archive`.
        cache_dir: cache root; defaults to `~/.keras`.

    # Returns
        Absolute path of the cached (and possibly extracted) file.
    """
    if cache_dir is None:  # default cache root is ~/.keras
        cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
    if md5_hash is not None and file_hash is None:
        # Legacy md5_hash argument: promote it to file_hash.
        file_hash = md5_hash
        hash_algorithm = 'md5'
    datadir_base = os.path.expanduser(cache_dir)
    if not os.access(datadir_base, os.W_OK):
        # Cache root is not writable; fall back to /tmp/.keras.
        datadir_base = os.path.join('/tmp', '.keras')
    datadir = os.path.join(datadir_base, cache_subdir)  # e.g. ~/.keras/datasets
    if not os.path.exists(datadir):
        os.makedirs(datadir)
    if untar:
        untar_fpath = os.path.join(datadir, fname)  # extracted location
        fpath = untar_fpath + '.tar.gz'             # archive location
    else:
        fpath = os.path.join(datadir, fname)
    download = False
    if os.path.exists(fpath):
        # File found; verify integrity if a hash was provided.
        if file_hash is not None:
            if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
                print('A local file was found, but it seems to be '
                      'incomplete or outdated because the ' + hash_algorithm +
                      ' file hash does not match the original value of ' +
                      file_hash + ' so we will re-download the data.')
                download = True
    else:
        download = True
    if download:
        print('Downloading data from', origin)

        class ProgressTracker(object):
            # Maintain progbar for the lifetime of download.
            # This design was chosen for Python 2.7 compatibility.
            progbar = None

        def dl_progress(count, block_size, total_size):
            # urlretrieve reporthook:
            #   count: number of blocks transferred so far
            #   block_size: size of a transfer block, in bytes
            #   total_size: total file size, or -1 if the server
            #       did not report one
            if ProgressTracker.progbar is None:
                # BUG FIX: compare with `==`, not `is` — identity checks
                # on integers rely on CPython small-int caching and are
                # not guaranteed to work.
                if total_size == -1:
                    total_size = None  # Progbar renders an "Unknown" target
                ProgressTracker.progbar = Progbar(total_size)
            else:
                ProgressTracker.progbar.update(count * block_size)

        error_msg = 'URL fetch failure on {}: {} -- {}'
        try:
            try:
                urlretrieve(origin, fpath, dl_progress)
            except HTTPError as e:
                raise Exception(error_msg.format(origin, e.code, e.msg))
            except URLError as e:
                raise Exception(error_msg.format(origin, e.errno, e.reason))
        except (Exception, KeyboardInterrupt):
            # Remove the partial download so the cache never holds a
            # corrupt file, then re-raise the original error.
            if os.path.exists(fpath):
                os.remove(fpath)
            raise
        ProgressTracker.progbar = None  # reset progbar state
    if untar:
        if not os.path.exists(untar_fpath):
            # Extract only if the target is missing (e.g. user removed it).
            _extract_archive(fpath, datadir, archive_format='tar')
        return untar_fpath
    if extract:
        _extract_archive(fpath, datadir, archive_format)
    return fpath
# get_file returns the absolute local path of the cached file.
Below is the source code of `keras.utils.generic_utils.Progbar`:
class Progbar(object):
    """Displays a text progress bar on stdout.

    # Arguments
        target: total number of steps expected; None if unknown.
        width: progress bar width on screen, in characters.
        verbose: verbosity mode: 0 (silent), 1 (verbose), 2 (semi-verbose).
        interval: minimum visual update interval, in seconds.
        stateful_metrics: iterable of metric names that should *not* be
            averaged over time; their last value is displayed as-is.
    """

    def __init__(self, target, width=30, verbose=1, interval=0.05,
                 stateful_metrics=None):
        self.target = target      # total steps expected, None if unknown
        self.width = width        # bar width on screen
        self.verbose = verbose    # 0: silent; 1: verbose; 2: semi-verbose
        self.interval = interval  # minimum redraw interval, in seconds
        if stateful_metrics:
            self.stateful_metrics = set(stateful_metrics)  # not averaged
        else:
            self.stateful_metrics = set()
        # In-place redrawing only works on a real terminal or inside an
        # IPython/Jupyter kernel ('ipykernel' in sys.modules).
        self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
                                  sys.stdout.isatty()) or
                                 'ipykernel' in sys.modules)
        self._total_width = 0   # characters printed on the current line
        self._seen_so_far = 0   # steps processed so far
        self._values = collections.OrderedDict()  # metric name -> [sum, count]
        self._start = time.time()  # time the progress started
        self._last_update = 0      # timestamp of the last redraw

    def update(self, current, values=None):
        """Updates the progress bar.

        # Arguments
            current: index of the current step (for downloads this is
                count * block_size).
            values: list of (name, value) pairs; values for metrics not
                listed in `stateful_metrics` are averaged over the steps
                seen since they first appeared.
        """
        values = values or []
        for k, v in values:
            if k not in self.stateful_metrics:
                if k not in self._values:
                    # First sighting: weight the value by the number of
                    # steps since the last update so the running average
                    # comes out correct.
                    self._values[k] = [v * (current - self._seen_so_far),
                                       current - self._seen_so_far]
                else:
                    # Accumulate weighted sum and step count.
                    self._values[k][0] += v * (current - self._seen_so_far)
                    self._values[k][1] += (current - self._seen_so_far)
            else:
                # Stateful metrics output a numeric value. This representation
                # means "take an average from a single value" but keeps the
                # numeric formatting.
                self._values[k] = [v, 1]
        self._seen_so_far = current
        now = time.time()
        info = ' - %.0fs' % (now - self._start)  # elapsed time since start
        if self.verbose == 1:
            # Throttle redraws: skip unless the interval has elapsed or
            # the bar just reached its target.
            if (now - self._last_update < self.interval and
                self.target is not None and current < self.target):
                return
            prev_total_width = self._total_width  # width of previous line
            if self._dynamic_display:
                # Rewind to the start of the line to redraw in place.
                sys.stdout.write('\b' * prev_total_width)
                sys.stdout.write('\r')
            else:
                sys.stdout.write('\n')
            if self.target is not None:
                # Digits needed to print the target,
                # e.g. numdigits of 1000 is 4.
                numdigits = int(np.floor(np.log10(self.target))) + 1
                # e.g. target 10000000 -> barstr '%8d/10000000 ['
                barstr = '%%%dd/%d [' % (numdigits, self.target)
                # e.g. current 10000 -> bar '   10000/10000000 ['
                bar = barstr % current
                prog = float(current) / self.target  # completed fraction
                prog_width = int(self.width * prog)  # filled bar width
                if prog_width > 0:
                    bar += ('=' * (prog_width - 1))
                    if current < self.target:
                        bar += '>'  # in progress: [======>
                    else:
                        bar += '='  # finished:    [=======
                bar += ('.' * (self.width - prog_width))  # pad with dots
                bar += ']'
            else:  # target size unknown
                bar = '%7d/Unknown' % current
            self._total_width = len(bar)
            sys.stdout.write(bar)
            if current:
                # Average time per step so far.
                time_per_unit = (now - self._start) / current
            else:
                time_per_unit = 0  # no data yet
            if self.target is not None and current < self.target:
                eta = time_per_unit * (self.target - current)
                # Render ETA in a human-readable format.
                if eta > 3600:
                    eta_format = ('%d:%02d:%02d' %
                                  (eta // 3600, (eta % 3600) // 60, eta % 60))
                elif eta > 60:
                    eta_format = '%d:%02d' % (eta // 60, eta % 60)
                else:
                    eta_format = '%ds' % eta
                # NOTE: replaces (not appends to) the elapsed-time info.
                info = ' - ETA: %s' % eta_format
            else:
                # Finished or unknown target: report time per step in
                # the largest fitting unit.
                if time_per_unit >= 1:
                    info += ' %.0fs/step' % time_per_unit
                elif time_per_unit >= 1e-3:
                    info += ' %.0fms/step' % (time_per_unit * 1e3)
                else:
                    info += ' %.0fus/step' % (time_per_unit * 1e6)
            for k in self._values:
                info += ' - %s:' % k
                if isinstance(self._values[k], list):
                    # Averaged metric: weighted sum / step count.
                    avg = np.mean(
                        self._values[k][0] / max(1, self._values[k][1]))
                    if abs(avg) > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg  # tiny values in scientific form
                else:
                    info += ' %s' % self._values[k]
            self._total_width += len(info)  # recount printed width
            if prev_total_width > self._total_width:
                # Pad with spaces to overwrite leftovers from a longer
                # previous line.
                info += (' ' * (prev_total_width - self._total_width))
            if self.target is not None and current >= self.target:
                info += '\n'  # move to the next line when finished
            sys.stdout.write(info)
            sys.stdout.flush()
        elif self.verbose == 2:
            # Semi-verbose: print one summary line at the end only.
            if self.target is None or current >= self.target:
                for k in self._values:
                    info += ' - %s:' % k
                    avg = np.mean(
                        self._values[k][0] / max(1, self._values[k][1]))
                    if avg > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg
                info += '\n'
                sys.stdout.write(info)
                sys.stdout.flush()
        self._last_update = now

    def add(self, n, values=None):
        """Advances the bar by `n` steps with `values` (see `update`)."""
        self.update(self._seen_so_far + n, values)
Summary: