In [399]:
##Loading
In [400]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
random_state=42
np.random.seed(random_state)
torch.manual_seed(random_state)
Out[400]:
<torch._C.Generator at 0x277da6e4e30>
In [401]:
mat_dict = sio.loadmat(file_name='umist_cropped.mat', appendmat=False)   # Load MATLAB file into a Python dictionary REF (scipy.io.loadmat): https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.loadmat.html
print(f'mat_dict type: {type(mat_dict)}, mat_dict.keys(): {mat_dict.keys()}')
print(f'\nfacedat type: {type(mat_dict["facedat"])}, facedat shape: {mat_dict["facedat"].shape}')
print(f'dirnames type: {type(mat_dict["dirnames"])}, dirnames shape: {mat_dict["dirnames"].shape}')
face_data = mat_dict['facedat'][0]   # Load the face data from the dictionary. Produces numpy.ndarray
print(f'\nNumber of people in images [len(face_data)]: {len(face_data)}\nImage height [len(face_data[0]]: {len(face_data[0])}\nImage width [len(face_data[0][0]]: {len(face_data[0][0])}\nImage total pixels [height X width]: {len(face_data[0]) * len(face_data[0][0])}')
mat_dict type: <class 'dict'>, mat_dict.keys(): dict_keys(['__header__', '__version__', '__globals__', 'facedat', 'dirnames'])

facedat type: <class 'numpy.ndarray'>, facedat shape: (1, 20)
dirnames type: <class 'numpy.ndarray'>, dirnames shape: (1, 20)

Number of people in images [len(face_data)]: 20
Image height [len(face_data[0]]: 112
Image width [len(face_data[0][0]]: 92
Image total pixels [height X width]: 10304
In [402]:
mat_dict
Out[402]:
{'__header__': b'MATLAB 5.0 MAT-file, Platform: GLNX86, Created on: Wed Aug 28 11:38:19 2002',
 '__version__': '1.0',
 '__globals__': [],
 'facedat': array([[array([[[233, 234, 234, ..., 236, 230, 234],
                 [234, 234, 234, ..., 235, 234, 232],
                 [234, 234, 234, ..., 236, 233, 234],
                 ...,
                 [234, 234, 234, ..., 237, 232, 233],
                 [234, 234, 234, ..., 236, 233, 234],
                 [234, 234, 233, ..., 236, 234, 233]],
 
                [[234, 234, 234, ..., 236, 232, 233],
                 [234, 234, 234, ..., 234, 233, 234],
                 [234, 234, 234, ..., 235, 233, 233],
                 ...,
                 [234, 234, 234, ..., 236, 230, 234],
                 [234, 234, 234, ..., 236, 234, 233],
                 [234, 234, 234, ..., 237, 233, 230]],
 
                [[234, 233, 234, ..., 236, 232, 233],
                 [234, 234, 233, ..., 235, 230, 234],
                 [234, 234, 234, ..., 237, 233, 233],
                 ...,
                 [234, 234, 234, ..., 237, 232, 233],
                 [234, 233, 234, ..., 239, 232, 233],
                 [234, 232, 234, ..., 236, 230, 232]],
 
                ...,
 
                [[166, 190, 206, ..., 219, 215, 187],
                 [177, 202, 208, ..., 216, 215, 186],
                 [194, 204, 205, ..., 213, 218, 187],
                 ...,
                 [147, 145, 161, ..., 142, 187, 211],
                 [144, 142, 157, ..., 176, 187, 214],
                 [141, 142, 155, ..., 187, 187, 215]],
 
                [[ 92, 179, 206, ..., 216, 219, 187],
                 [113, 197, 206, ..., 215, 220, 187],
                 [143, 199, 201, ..., 215, 220, 190],
                 ...,
                 [141, 142, 157, ..., 186, 188, 215],
                 [142, 138, 152, ..., 190, 187, 216],
                 [140, 140, 154, ..., 192, 187, 216]],
 
                [[ 65, 104, 204, ..., 215, 218, 184],
                 [100, 135, 206, ..., 214, 220, 184],
                 [123, 159, 194, ..., 215, 220, 188],
                 ...,
                 [141, 138, 148, ..., 192, 187, 215],
                 [140, 137, 145, ..., 192, 187, 215],
                 [136, 137, 145, ..., 193, 185, 215]]], dtype=uint8),
         array([[[232, 233, 232, ..., 234, 234, 221],
                 [229, 232, 232, ..., 234, 234, 209],
                 [230, 230, 232, ..., 234, 234, 140],
                 ...,
                 [230, 229, 230, ...,  42,  43,  66],
                 [229, 229, 230, ...,  64,  64, 128],
                 [229, 230, 230, ...,  91, 105, 201]],
 
                [[229, 232, 232, ..., 233, 234, 207],
                 [234, 230, 230, ..., 234, 234, 174],
                 [229, 230, 229, ..., 233, 234, 102],
                 ...,
                 [229, 232, 232, ...,  41,  39,  57],
                 [229, 230, 229, ...,  59,  51,  86],
                 [229, 230, 229, ...,  87,  84, 165]],
 
                [[232, 230, 230, ..., 234, 234, 165],
                 [230, 229, 230, ..., 234, 233, 108],
                 [229, 230, 229, ..., 234, 234,  93],
                 ...,
                 [229, 229, 230, ...,  38,  38,  56],
                 [229, 230, 229, ...,  45,  44,  63],
                 [229, 230, 229, ...,  70,  57, 104]],
 
                ...,
 
                [[ 83,  85,  97, ..., 200, 200, 200],
                 [ 87,  92,  93, ..., 200, 200, 201],
                 [ 92,  97,  92, ..., 201, 197, 197],
                 ...,
                 [122, 122, 122, ..., 127, 138, 162],
                 [121, 123, 120, ..., 126, 130, 152],
                 [123, 123, 122, ..., 125, 126, 143]],
 
                [[ 80,  84, 100, ..., 198, 199, 197],
                 [ 85,  92,  99, ..., 200, 198, 197],
                 [ 85,  92,  95, ..., 200, 197, 195],
                 ...,
                 [122, 122, 125, ..., 125, 130, 158],
                 [122, 122, 122, ..., 126, 128, 149],
                 [122, 121, 123, ..., 125, 125, 138]],
 
                [[ 83,  78,  99, ..., 161, 193, 197],
                 [ 84,  84, 101, ..., 182, 197, 197],
                 [ 77,  86, 105, ..., 198, 197, 195],
                 ...,
                 [121, 122, 123, ..., 123, 127, 148],
                 [122, 122, 123, ..., 126, 125, 140],
                 [122, 122, 123, ..., 125, 123, 131]]], dtype=uint8),
         array([[[233, 230, 234, ..., 234, 234, 234],
                 [232, 209, 232, ..., 234, 234, 234],
                 [230, 113, 195, ..., 234, 234, 234],
                 ...,
                 [230, 232, 232, ..., 234, 234, 234],
                 [232, 232, 233, ..., 234, 234, 234],
                 [233, 232, 230, ..., 234, 234, 234]],
 
                [[233, 229, 230, ..., 234, 235, 234],
                 [230, 197, 229, ..., 234, 234, 234],
                 [232,  94, 194, ..., 234, 234, 234],
                 ...,
                 [230, 232, 232, ..., 234, 234, 234],
                 [232, 229, 230, ..., 234, 234, 234],
                 [230, 229, 230, ..., 234, 234, 234]],
 
                [[233, 178, 229, ..., 234, 234, 234],
                 [232, 143, 208, ..., 234, 234, 234],
                 [233,  57, 135, ..., 234, 234, 234],
                 ...,
                 [230, 232, 233, ..., 234, 233, 234],
                 [232, 230, 229, ..., 234, 234, 234],
                 [233, 229, 230, ..., 233, 234, 234]],
 
                ...,
 
                [[170, 194, 194, ..., 205, 201, 201],
                 [173, 198, 199, ..., 205, 202, 201],
                 [192, 193, 197, ..., 201, 201, 200],
                 ...,
                 [123, 114,  71, ..., 135, 127, 147],
                 [123, 123, 109, ..., 133, 127, 143],
                 [123, 126, 123, ..., 128, 126, 135]],
 
                [[168, 190, 190, ..., 201, 205, 201],
                 [170, 193, 197, ..., 204, 204, 201],
                 [190, 193, 193, ..., 200, 201, 201],
                 ...,
                 [122, 119,  78, ..., 134, 127, 141],
                 [122, 122, 112, ..., 129, 127, 138],
                 [122, 123, 123, ..., 126, 127, 131]],
 
                [[166, 188, 190, ..., 202, 201, 202],
                 [170, 197, 195, ..., 201, 201, 201],
                 [186, 192, 194, ..., 200, 201, 201],
                 ...,
                 [122, 121, 115, ..., 128, 126, 133],
                 [120, 121, 122, ..., 126, 126, 130],
                 [121, 122, 123, ..., 122, 126, 127]]], dtype=uint8),
         array([[[141, 234, 230, ..., 225, 226, 225],
                 [116, 232, 233, ..., 225, 228, 226],
                 [104, 234, 232, ..., 226, 230, 228],
                 ...,
                 [230, 233, 232, ..., 234, 234, 234],
                 [232, 232, 232, ..., 234, 234, 234],
                 [232, 233, 232, ..., 234, 234, 234]],
 
                [[129, 233, 233, ..., 225, 225, 222],
                 [111, 233, 233, ..., 225, 227, 223],
                 [100, 234, 233, ..., 229, 229, 225],
                 ...,
                 [229, 233, 232, ..., 234, 234, 234],
                 [233, 232, 230, ..., 234, 234, 234],
                 [229, 229, 232, ..., 234, 234, 234]],
 
                [[116, 232, 233, ..., 226, 225, 222],
                 [111, 233, 234, ..., 225, 225, 225],
                 [ 70, 232, 233, ..., 229, 229, 225],
                 ...,
                 [232, 233, 232, ..., 234, 234, 233],
                 [230, 230, 232, ..., 234, 234, 234],
                 [232, 230, 232, ..., 234, 233, 232]],
 
                ...,
 
                [[194, 197, 201, ..., 173, 171, 171],
                 [190, 197, 201, ..., 171, 171, 172],
                 [ 90, 197, 199, ..., 173, 173, 172],
                 ...,
                 [129, 126, 128, ..., 183, 180, 180],
                 [129, 127, 127, ..., 180, 179, 179],
                 [128, 126, 127, ..., 183, 179, 180]],
 
                [[195, 199, 200, ..., 174, 173, 171],
                 [192, 199, 200, ..., 173, 173, 172],
                 [138, 199, 197, ..., 174, 174, 171],
                 ...,
                 [128, 127, 126, ..., 184, 182, 178],
                 [129, 127, 127, ..., 183, 180, 178],
                 [130, 127, 127, ..., 183, 183, 178]],
 
                [[194, 200, 200, ..., 173, 173, 170],
                 [192, 200, 198, ..., 173, 172, 172],
                 [147, 197, 197, ..., 173, 174, 170],
                 ...,
                 [127, 127, 126, ..., 182, 182, 178],
                 [127, 127, 127, ..., 182, 182, 177],
                 [127, 128, 125, ..., 183, 183, 171]]], dtype=uint8),
         array([[[193, 193, 192, ..., 197, 193, 194],
                 [192, 193, 192, ..., 197, 192, 194],
                 [192, 193, 190, ..., 197, 193, 192],
                 ...,
                 [193, 192, 192, ..., 195, 192, 191],
                 [193, 192, 192, ..., 194, 193, 192],
                 [192, 193, 192, ..., 194, 193, 192]],
 
                [[192, 192, 193, ..., 197, 197, 195],
                 [191, 192, 191, ..., 197, 194, 195],
                 [191, 192, 192, ..., 197, 194, 193],
                 ...,
                 [191, 191, 192, ..., 192, 194, 192],
                 [192, 192, 192, ..., 192, 192, 193],
                 [192, 192, 194, ..., 192, 194, 193]],
 
                [[192, 192, 192, ..., 195, 197, 195],
                 [191, 192, 192, ..., 194, 193, 195],
                 [191, 191, 191, ..., 192, 193, 193],
                 ...,
                 [192, 192, 192, ..., 192, 194, 193],
                 [192, 192, 192, ..., 192, 193, 192],
                 [191, 192, 192, ..., 192, 195, 192]],
 
                ...,
 
                [[145, 156, 148, ..., 147, 169, 169],
                 [166, 172, 169, ..., 148, 174, 173],
                 [168, 165, 168, ..., 165, 173, 171],
                 ...,
                 [123, 122, 125, ..., 155, 155, 154],
                 [125, 123, 127, ..., 155, 155, 155],
                 [122, 122, 127, ..., 155, 155, 155]],
 
                [[143, 155, 145, ..., 144, 164, 165],
                 [164, 169, 171, ..., 148, 173, 173],
                 [166, 168, 170, ..., 164, 172, 171],
                 ...,
                 [122, 121, 125, ..., 155, 155, 155],
                 [123, 122, 127, ..., 155, 155, 152],
                 [122, 122, 126, ..., 155, 154, 154]],
 
                [[144, 154, 145, ..., 144, 164, 164],
                 [163, 169, 168, ..., 145, 173, 171],
                 [168, 166, 170, ..., 161, 171, 170],
                 ...,
                 [121, 120, 123, ..., 155, 155, 154],
                 [122, 122, 123, ..., 154, 155, 152],
                 [122, 123, 125, ..., 150, 154, 152]]], dtype=uint8),
         array([[[173, 173, 173, ..., 174, 174, 173],
                 [173, 176, 171, ..., 176, 174, 172],
                 [173, 173, 170, ..., 176, 173, 173],
                 ...,
                 [177, 179, 178, ..., 178, 179, 179],
                 [178, 179, 177, ..., 179, 178, 178],
                 [178, 178, 176, ..., 178, 179, 178]],
 
                [[173, 173, 174, ..., 173, 173, 173],
                 [172, 173, 174, ..., 173, 173, 171],
                 [172, 173, 173, ..., 173, 174, 173],
                 ...,
                 [176, 178, 179, ..., 178, 178, 177],
                 [177, 178, 178, ..., 178, 178, 177],
                 [174, 177, 178, ..., 178, 177, 177]],
 
                [[170, 173, 173, ..., 174, 173, 173],
                 [171, 173, 173, ..., 173, 173, 174],
                 [171, 171, 173, ..., 173, 173, 170],
                 ...,
                 [176, 176, 178, ..., 177, 178, 176],
                 [176, 178, 178, ..., 178, 178, 178],
                 [176, 177, 178, ..., 177, 178, 178]],
 
                ...,
 
                [[ 29,  34,  35, ...,  87,  86, 147],
                 [ 29,  34,  34, ...,  67,  66, 115],
                 [ 29,  30,  30, ...,  43,  43,  63],
                 ...,
                 [119, 119, 117, ..., 119, 119, 123],
                 [119, 121, 117, ..., 117, 116, 120],
                 [117, 119, 119, ..., 117, 117, 119]],
 
                [[ 30,  34,  35, ...,  70,  66, 120],
                 [ 29,  31,  33, ...,  55,  53,  79],
                 [ 29,  29,  29, ...,  37,  36,  52],
                 ...,
                 [117, 120, 117, ..., 117, 117, 120],
                 [117, 120, 117, ..., 117, 116, 119],
                 [119, 121, 119, ..., 116, 115, 117]],
 
                [[ 36,  30,  30, ...,  55,  52,  81],
                 [ 36,  29,  29, ...,  46,  45,  63],
                 [ 38,  30,  31, ...,  36,  36,  44],
                 ...,
                 [117, 120, 116, ..., 116, 117, 116],
                 [117, 120, 117, ..., 117, 114, 117],
                 [115, 119, 117, ..., 116, 113, 117]]], dtype=uint8),
         array([[[179, 177, 178, ..., 168, 145, 148],
                 [179, 178, 182, ..., 173, 147, 150],
                 [177, 178, 180, ..., 178, 147, 149],
                 ...,
                 [163, 162, 164, ..., 164, 164, 164],
                 [163, 161, 164, ..., 164, 164, 164],
                 [163, 161, 164, ..., 164, 164, 165]],
 
                [[178, 174, 177, ..., 166, 148, 147],
                 [178, 178, 178, ..., 176, 150, 147],
                 [173, 180, 178, ..., 178, 150, 147],
                 ...,
                 [162, 163, 162, ..., 164, 164, 163],
                 [161, 164, 161, ..., 164, 164, 163],
                 [162, 163, 162, ..., 164, 164, 164]],
 
                [[178, 174, 177, ..., 163, 145, 148],
                 [178, 178, 178, ..., 172, 148, 148],
                 [177, 180, 179, ..., 177, 147, 148],
                 ...,
                 [159, 163, 162, ..., 164, 164, 163],
                 [161, 163, 163, ..., 164, 164, 163],
                 [159, 163, 161, ..., 164, 164, 164]],
 
                ...,
 
                [[ 43,  43,  31, ...,  69,  36,  42],
                 [ 43,  43,  28, ...,  55,  31,  46],
                 [ 37,  43,  35, ...,  23,  24,  45],
                 ...,
                 [113, 113, 113, ..., 115, 114, 113],
                 [114, 112, 114, ..., 114, 115, 114],
                 [114, 113, 113, ..., 114, 113, 113]],
 
                [[ 43,  43,  30, ...,  72,  34,  37],
                 [ 43,  43,  34, ...,  60,  31,  44],
                 [ 37,  43,  42, ...,  24,  24,  44],
                 ...,
                 [113, 113, 113, ..., 113, 116, 113],
                 [113, 113, 113, ..., 113, 114, 113],
                 [113, 114, 113, ..., 113, 114, 113]],
 
                [[ 43,  49,  31, ...,  92,  35,  29],
                 [ 43,  44,  31, ...,  77,  30,  37],
                 [ 41,  43,  38, ...,  38,  23,  43],
                 ...,
                 [112, 113, 113, ..., 113, 116, 114],
                 [113, 113, 113, ..., 113, 115, 115],
                 [113, 114, 113, ..., 113, 113, 114]]], dtype=uint8),
         array([[[169, 169, 170, ..., 169, 170, 169],
                 [168, 169, 170, ..., 170, 169, 169],
                 [168, 169, 169, ..., 170, 169, 165],
                 ...,
                 [158, 157, 161, ..., 157, 158, 159],
                 [158, 158, 158, ..., 157, 159, 159],
                 [157, 159, 159, ..., 159, 157, 159]],
 
                [[170, 172, 168, ..., 170, 171, 164],
                 [172, 172, 169, ..., 169, 170, 165],
                 [168, 170, 168, ..., 170, 172, 165],
                 ...,
                 [159, 159, 158, ..., 159, 159, 156],
                 [159, 161, 157, ..., 159, 159, 158],
                 [159, 162, 158, ..., 159, 161, 158]],
 
                [[169, 169, 169, ..., 169, 170, 166],
                 [169, 169, 169, ..., 169, 169, 165],
                 [166, 168, 169, ..., 169, 168, 165],
                 ...,
                 [157, 159, 159, ..., 157, 158, 159],
                 [157, 161, 159, ..., 159, 159, 158],
                 [158, 159, 161, ..., 159, 159, 159]],
 
                ...,
 
                [[ 41,  37,  53, ..., 163, 157, 159],
                 [ 39,  38,  62, ..., 157, 151, 164],
                 [ 34,  36,  57, ..., 151, 150, 164],
                 ...,
                 [ 63,  72,  76, ..., 127,  97, 136],
                 [ 91, 102, 105, ..., 122,  97, 136],
                 [113, 113, 113, ..., 121, 101, 136]],
 
                [[ 39,  44,  37, ..., 161, 163, 157],
                 [ 37,  43,  41, ..., 159, 161, 163],
                 [ 34,  37,  53, ..., 159, 158, 163],
                 ...,
                 [ 66,  90,  98, ..., 127, 129, 136],
                 [ 94, 113, 113, ..., 119, 128, 135],
                 [112, 113, 113, ..., 117, 127, 135]],
 
                [[ 44,  44,  34, ..., 161, 161, 156],
                 [ 43,  43,  34, ..., 161, 159, 159],
                 [ 38,  38,  41, ..., 159, 158, 159],
                 ...,
                 [ 80,  94, 112, ..., 123, 129, 134],
                 [109, 111, 113, ..., 117, 127, 131],
                 [113, 113, 112, ..., 116, 125, 134]]], dtype=uint8),
         array([[[173, 187, 183, ..., 186, 187, 190],
                 [161, 186, 183, ..., 184, 184, 187],
                 [141, 183, 183, ..., 179, 180, 186],
                 ...,
                 [174, 173, 173, ..., 147, 178, 177],
                 [176, 173, 173, ..., 156, 178, 174],
                 [176, 173, 173, ..., 162, 178, 177]],
 
                [[166, 186, 185, ..., 184, 182, 190],
                 [156, 186, 184, ..., 180, 182, 191],
                 [134, 184, 184, ..., 177, 178, 187],
                 ...,
                 [173, 173, 174, ..., 142, 178, 178],
                 [172, 173, 177, ..., 150, 176, 178],
                 [173, 174, 176, ..., 156, 177, 178]],
 
                [[152, 185, 183, ..., 179, 182, 188],
                 [140, 183, 183, ..., 174, 177, 187],
                 [122, 184, 182, ..., 166, 172, 188],
                 ...,
                 [174, 173, 173, ..., 142, 178, 177],
                 [173, 173, 176, ..., 149, 177, 178],
                 [174, 173, 173, ..., 150, 178, 176]],
 
                ...,
 
                [[ 39,  45,  38, ...,  39,  45,  38],
                 [ 42,  45,  38, ...,  37,  46,  38],
                 [ 37,  46,  37, ...,  36,  48,  35],
                 ...,
                 [114, 114, 113, ...,  33,  76,  77],
                 [113, 114, 113, ...,  43,  97,  84],
                 [114, 116, 113, ...,  56, 104,  86]],
 
                [[ 38,  44,  38, ...,  37,  44,  38],
                 [ 42,  45,  35, ...,  34,  46,  38],
                 [ 35,  45,  38, ...,  33,  48,  34],
                 ...,
                 [115, 113, 113, ...,  35,  99,  76],
                 [114, 113, 114, ...,  57, 107,  84],
                 [116, 113, 114, ...,  76, 113,  86]],
 
                [[ 39,  42,  34, ...,  36,  37,  37],
                 [ 42,  44,  34, ...,  34,  41,  37],
                 [ 34,  45,  37, ...,  33,  43,  34],
                 ...,
                 [113, 113, 113, ...,  34, 101,  85],
                 [113, 113, 113, ...,  58, 107,  88],
                 [114, 114, 113, ...,  78, 112,  92]]], dtype=uint8),
         array([[[214, 214, 212, ..., 215, 206, 218],
                 [213, 213, 211, ..., 218, 207, 215],
                 [212, 213, 214, ..., 215, 204, 215],
                 ...,
                 [205, 204, 201, ..., 202, 215, 205],
                 [204, 205, 201, ..., 204, 219, 205],
                 [202, 204, 204, ..., 204, 218, 206]],
 
                [[215, 215, 214, ..., 216, 204, 215],
                 [215, 214, 214, ..., 216, 205, 214],
                 [214, 214, 214, ..., 215, 201, 215],
                 ...,
                 [206, 206, 205, ..., 206, 214, 204],
                 [206, 206, 206, ..., 206, 215, 206],
                 [206, 206, 206, ..., 206, 215, 201]],
 
                [[215, 215, 214, ..., 215, 205, 215],
                 [215, 215, 215, ..., 215, 204, 214],
                 [214, 214, 213, ..., 215, 201, 214],
                 ...,
                 [206, 206, 205, ..., 206, 215, 202],
                 [206, 206, 205, ..., 206, 216, 201],
                 [206, 206, 205, ..., 206, 215, 206]],
 
                ...,
 
                [[200, 201, 199, ..., 173, 131, 173],
                 [200, 201, 198, ..., 173, 130, 173],
                 [199, 201, 198, ..., 174, 131, 174],
                 ...,
                 [130, 131, 131, ..., 128, 174, 129],
                 [131, 131, 131, ..., 128, 176, 128],
                 [130, 131, 131, ..., 128, 172, 128]],
 
                [[201, 201, 199, ..., 172, 129, 172],
                 [201, 200, 198, ..., 171, 130, 173],
                 [201, 198, 194, ..., 172, 130, 173],
                 ...,
                 [131, 131, 130, ..., 128, 173, 129],
                 [131, 131, 129, ..., 128, 173, 129],
                 [131, 131, 129, ..., 128, 173, 128]],
 
                [[201, 201, 197, ..., 170, 129, 172],
                 [201, 201, 197, ..., 173, 128, 172],
                 [199, 198, 193, ..., 171, 130, 173],
                 ...,
                 [131, 131, 131, ..., 127, 172, 129],
                 [133, 130, 127, ..., 129, 174, 129],
                 [131, 131, 129, ..., 128, 171, 128]]], dtype=uint8),
         array([[[207, 206, 211, ..., 218, 220, 219],
                 [207, 207, 211, ..., 219, 220, 218],
                 [207, 206, 208, ..., 216, 220, 219],
                 ...,
                 [206, 202, 206, ..., 208, 208, 209],
                 [206, 204, 206, ..., 208, 211, 207],
                 [205, 205, 204, ..., 208, 209, 208]],
 
                [[207, 208, 208, ..., 220, 219, 219],
                 [206, 209, 208, ..., 219, 219, 220],
                 [206, 207, 208, ..., 220, 218, 220],
                 ...,
                 [204, 205, 202, ..., 208, 207, 209],
                 [204, 206, 204, ..., 211, 208, 211],
                 [201, 205, 202, ..., 211, 208, 211]],
 
                [[209, 209, 211, ..., 218, 220, 220],
                 [208, 208, 209, ..., 215, 220, 220],
                 [209, 206, 208, ..., 218, 218, 219],
                 ...,
                 [206, 202, 206, ..., 206, 209, 211],
                 [206, 204, 204, ..., 207, 209, 211],
                 [204, 204, 206, ..., 208, 211, 211]],
 
                ...,
 
                [[192, 195, 192, ..., 206, 208, 206],
                 [192, 192, 192, ..., 206, 206, 206],
                 [191, 192, 191, ..., 204, 206, 206],
                 ...,
                 [133, 134, 131, ..., 137, 137, 136],
                 [134, 134, 134, ..., 136, 136, 136],
                 [133, 134, 133, ..., 135, 136, 136]],
 
                [[192, 192, 193, ..., 206, 206, 206],
                 [193, 192, 192, ..., 205, 206, 205],
                 [191, 191, 192, ..., 202, 204, 206],
                 ...,
                 [133, 131, 133, ..., 136, 137, 135],
                 [131, 133, 131, ..., 136, 136, 138],
                 [131, 134, 131, ..., 136, 136, 137]],
 
                [[192, 192, 192, ..., 206, 206, 205],
                 [190, 191, 192, ..., 206, 204, 204],
                 [188, 190, 188, ..., 206, 204, 205],
                 ...,
                 [131, 131, 131, ..., 136, 136, 135],
                 [131, 131, 131, ..., 136, 136, 136],
                 [131, 131, 131, ..., 136, 136, 136]]], dtype=uint8),
         array([[[225, 225, 226, ..., 229, 229, 229],
                 [223, 225, 225, ..., 229, 229, 229],
                 [225, 225, 225, ..., 229, 229, 229],
                 ...,
                 [214, 214, 206, ..., 215, 215, 214],
                 [216, 215, 215, ..., 216, 215, 214],
                 [216, 218, 215, ..., 215, 213, 214]],
 
                [[225, 225, 225, ..., 229, 229, 229],
                 [223, 226, 225, ..., 229, 229, 229],
                 [221, 223, 225, ..., 229, 229, 229],
                 ...,
                 [215, 207, 201, ..., 215, 216, 215],
                 [219, 215, 212, ..., 218, 216, 215],
                 [220, 219, 213, ..., 215, 216, 215]],
 
                [[225, 225, 225, ..., 229, 229, 228],
                 [221, 225, 226, ..., 229, 229, 229],
                 [218, 222, 225, ..., 229, 229, 229],
                 ...,
                 [215, 201, 176, ..., 216, 215, 216],
                 [218, 213, 202, ..., 219, 216, 216],
                 [218, 215, 206, ..., 215, 215, 219]],
 
                ...,
 
                [[ 16,  16, 208, ..., 192, 204, 190],
                 [ 16,  15, 197, ..., 201, 214, 195],
                 [ 15,  15, 142, ..., 220, 218, 208],
                 ...,
                 [141, 141, 143, ..., 141, 136, 136],
                 [141, 141, 143, ..., 141, 140, 136],
                 [141, 141, 144, ..., 140, 137, 140]],
 
                [[ 15,  20, 215, ..., 192, 202, 187],
                 [ 15,  15, 212, ..., 198, 213, 193],
                 [ 15,  16, 173, ..., 215, 218, 204],
                 ...,
                 [141, 142, 142, ..., 136, 137, 136],
                 [141, 142, 142, ..., 141, 141, 136],
                 [144, 141, 143, ..., 137, 141, 136]],
 
                [[ 15,  19, 214, ..., 190, 201, 186],
                 [ 16,  15, 211, ..., 194, 213, 191],
                 [ 15,  15, 174, ..., 216, 220, 201],
                 ...,
                 [142, 142, 140, ..., 137, 140, 136],
                 [141, 142, 141, ..., 140, 141, 135],
                 [142, 142, 141, ..., 137, 140, 136]]], dtype=uint8),
         array([[[170, 168, 168, ..., 171, 170, 170],
                 [170, 168, 166, ..., 171, 168, 170],
                 [172, 170, 169, ..., 169, 170, 169],
                 ...,
                 [173, 169, 168, ..., 173, 169, 173],
                 [173, 170, 169, ..., 172, 170, 173],
                 [173, 171, 169, ..., 171, 169, 171]],
 
                [[170, 165, 169, ..., 170, 171, 169],
                 [169, 168, 168, ..., 171, 170, 168],
                 [170, 169, 169, ..., 169, 171, 169],
                 ...,
                 [171, 168, 170, ..., 173, 172, 171],
                 [172, 166, 171, ..., 170, 171, 170],
                 [172, 166, 172, ..., 170, 173, 169]],
 
                [[169, 166, 168, ..., 170, 171, 166],
                 [168, 166, 169, ..., 169, 169, 168],
                 [169, 166, 168, ..., 170, 170, 169],
                 ...,
                 [171, 165, 169, ..., 169, 173, 170],
                 [171, 166, 168, ..., 172, 171, 169],
                 [170, 169, 169, ..., 170, 172, 169]],
 
                ...,
 
                [[161, 162, 161, ..., 137, 140, 137],
                 [159, 161, 159, ..., 138, 138, 138],
                 [159, 161, 161, ..., 140, 142, 141],
                 ...,
                 [117, 117, 117, ..., 117, 117, 117],
                 [116, 116, 114, ..., 117, 117, 117],
                 [117, 116, 115, ..., 116, 117, 117]],
 
                [[159, 161, 163, ..., 137, 141, 136],
                 [158, 159, 161, ..., 138, 138, 137],
                 [158, 159, 159, ..., 140, 140, 138],
                 ...,
                 [116, 117, 116, ..., 117, 117, 117],
                 [115, 113, 117, ..., 117, 117, 117],
                 [115, 115, 117, ..., 116, 119, 117]],
 
                [[159, 159, 164, ..., 136, 140, 136],
                 [159, 159, 159, ..., 137, 138, 137],
                 [158, 159, 159, ..., 140, 141, 138],
                 ...,
                 [115, 114, 116, ..., 119, 116, 117],
                 [114, 115, 117, ..., 117, 117, 119],
                 [114, 115, 116, ..., 116, 117, 117]]], dtype=uint8),
         array([[[135, 136, 136, ..., 162, 164, 164],
                 [135, 136, 136, ..., 163, 164, 159],
                 [136, 136, 136, ..., 163, 164, 162],
                 ...,
                 [162, 163, 161, ..., 134, 136, 136],
                 [161, 163, 159, ..., 136, 134, 136],
                 [161, 164, 163, ..., 134, 136, 136]],
 
                [[136, 134, 136, ..., 164, 163, 163],
                 [136, 134, 136, ..., 164, 163, 164],
                 [136, 133, 136, ..., 163, 162, 164],
                 ...,
                 [164, 162, 161, ..., 136, 134, 136],
                 [163, 162, 163, ..., 136, 133, 136],
                 [164, 161, 164, ..., 134, 136, 136]],
 
                [[136, 135, 136, ..., 164, 162, 164],
                 [136, 134, 135, ..., 164, 161, 164],
                 [136, 133, 136, ..., 163, 162, 164],
                 ...,
                 [164, 159, 161, ..., 133, 136, 136],
                 [164, 161, 159, ..., 134, 134, 135],
                 [164, 161, 161, ..., 134, 136, 135]],
 
                ...,
 
                [[108, 108,  80, ..., 136, 131, 133],
                 [106, 108,  70, ..., 136, 133, 134],
                 [105, 109, 106, ..., 136, 134, 136],
                 ...,
                 [113, 109, 112, ...,  65,  83,  92],
                 [113, 108, 112, ...,  79,  97, 101],
                 [113, 112, 113, ...,  87, 101, 101]],
 
                [[106, 108, 101, ..., 135, 131, 136],
                 [104, 111,  95, ..., 136, 136, 136],
                 [105, 112, 105, ..., 134, 136, 136],
                 ...,
                 [112, 112, 111, ...,  80,  98, 101],
                 [112, 113, 112, ...,  98, 105, 107],
                 [109, 112, 113, ..., 102, 108, 106]],
 
                [[108, 108, 107, ..., 131, 129, 134],
                 [105, 108, 105, ..., 133, 134, 135],
                 [105, 109, 101, ..., 134, 135, 136],
                 ...,
                 [108, 111, 113, ...,  84, 104, 106],
                 [109, 111, 113, ...,  99, 107, 109],
                 [111, 111, 114, ..., 101, 109, 107]]], dtype=uint8),
         array([[[166, 127, 168, ..., 173, 171, 170],
                 [168, 149, 168, ..., 172, 172, 170],
                 [166, 128, 169, ..., 171, 172, 170],
                 ...,
                 [172,   1, 173, ..., 150, 150, 145],
                 [171,   1, 173, ..., 150, 149, 145],
                 [171,   1, 173, ..., 145, 147, 143]],
 
                [[165, 165, 164, ..., 171, 173, 172],
                 [164, 164, 165, ..., 171, 173, 173],
                 [165, 164, 166, ..., 170, 173, 173],
                 ...,
                 [170, 169, 169, ..., 148, 150, 148],
                 [169, 169, 169, ..., 150, 150, 145],
                 [169, 171, 171, ..., 145, 147, 145]],
 
                [[165, 164, 164, ..., 170, 173, 169],
                 [164, 164, 164, ..., 171, 173, 171],
                 [164, 165, 164, ..., 170, 173, 171],
                 ...,
                 [169, 170, 171, ..., 148, 150, 147],
                 [169, 169, 170, ..., 147, 150, 145],
                 [169, 170, 171, ..., 147, 147, 143]],
 
                ...,
 
                [[143, 142, 145, ...,  46,  63,  62],
                 [144, 147, 133, ...,  60,  66,  63],
                 [136, 143, 125, ...,  65,  66,  63],
                 ...,
                 [ 36,  49,  36, ..., 127, 129, 129],
                 [ 39, 104,  38, ..., 128, 130, 130],
                 [ 42, 119,  38, ..., 128, 127, 128]],
 
                [[150, 150, 136, ...,  50,  62,  62],
                 [151, 150, 138, ...,  59,  63,  63],
                 [148, 149, 130, ...,  62,  66,  62],
                 ...,
                 [ 34,  74,  35, ..., 129, 127, 127],
                 [ 38, 116,  36, ..., 127, 131, 127],
                 [ 49, 121,  39, ..., 127, 127, 127]],
 
                [[154, 152, 143, ...,  55,  59,  60],
                 [156, 151, 142, ...,  57,  64,  62],
                 [154, 155, 137, ...,  62,  66,  62],
                 ...,
                 [ 37, 112,  37, ..., 129, 129, 127],
                 [ 44, 117,  38, ..., 127, 131, 129],
                 [ 80, 122,  55, ..., 126, 128, 127]]], dtype=uint8),
         array([[[173, 169, 170, ..., 173, 173, 172],
                 [171, 170, 171, ..., 172, 171, 172],
                 [171, 170, 171, ..., 173, 171, 171],
                 ...,
                 [173, 172, 173, ..., 161, 161, 168],
                 [173, 173, 173, ..., 164, 164, 168],
                 [174, 172, 173, ..., 166, 164, 171]],
 
                [[171, 171, 169, ..., 173, 173, 173],
                 [169, 171, 169, ..., 173, 172, 173],
                 [171, 170, 169, ..., 173, 172, 170],
                 ...,
                 [173, 173, 172, ..., 151, 158, 169],
                 [172, 173, 172, ..., 166, 165, 170],
                 [173, 173, 171, ..., 168, 164, 173]],
 
                [[169, 169, 170, ..., 169, 172, 173],
                 [170, 169, 171, ..., 172, 171, 173],
                 [168, 170, 171, ..., 172, 170, 173],
                 ...,
                 [172, 173, 173, ..., 142, 155, 169],
                 [172, 173, 174, ..., 161, 163, 169],
                 [172, 173, 173, ..., 165, 163, 172]],
 
                ...,
 
                [[156, 159, 159, ..., 117, 117, 117],
                 [156, 161, 158, ..., 119, 117, 117],
                 [155, 159, 157, ..., 117, 117, 117],
                 ...,
                 [117, 117, 117, ..., 159, 163, 168],
                 [117, 120, 116, ..., 164, 164, 162],
                 [117, 117, 117, ..., 159, 164, 150]],
 
                [[158, 159, 158, ..., 117, 119, 117],
                 [156, 159, 158, ..., 117, 117, 117],
                 [155, 158, 158, ..., 117, 117, 117],
                 ...,
                 [117, 117, 117, ..., 159, 162, 165],
                 [117, 117, 117, ..., 161, 164, 157],
                 [117, 119, 117, ..., 161, 165, 147]],
 
                [[157, 158, 159, ..., 117, 116, 116],
                 [155, 158, 159, ..., 117, 114, 117],
                 [155, 156, 159, ..., 120, 116, 117],
                 ...,
                 [117, 117, 117, ..., 164, 159, 163],
                 [117, 117, 119, ..., 163, 164, 155],
                 [117, 117, 117, ..., 163, 164, 144]]], dtype=uint8),
         array([[[148, 149, 147, ..., 149,   1, 150],
                 [145, 149, 145, ..., 151,   1, 152],
                 [147, 149, 145, ..., 152,   1, 150],
                 ...,
                 [172, 173, 173, ..., 173,   1, 173],
                 [173, 173, 173, ..., 173,   1, 172],
                 [173, 172, 172, ..., 173,   1, 173]],
 
                [[144, 145, 145, ..., 150, 150, 151],
                 [145, 145, 145, ..., 155, 150, 151],
                 [147, 145, 144, ..., 154, 154, 151],
                 ...,
                 [170, 171, 173, ..., 173, 173, 174],
                 [170, 170, 173, ..., 176, 173, 173],
                 [171, 169, 169, ..., 173, 173, 173]],
 
                [[144, 147, 145, ..., 151, 150, 150],
                 [145, 145, 145, ..., 150, 150, 149],
                 [147, 147, 147, ..., 151, 155, 149],
                 ...,
                 [169, 170, 170, ..., 172, 173, 173],
                 [172, 170, 170, ..., 173, 174, 171],
                 [172, 169, 169, ..., 173, 173, 173]],
 
                ...,
 
                [[133, 131, 133, ..., 137, 136, 135],
                 [134, 134, 133, ..., 141, 136, 135],
                 [133, 133, 133, ..., 137, 136, 135],
                 ...,
                 [120,  77, 105, ..., 150, 147, 143],
                 [122,  98, 120, ..., 150, 149, 138],
                 [120, 107, 121, ..., 149, 147, 136]],
 
                [[133, 133, 131, ..., 136, 136, 136],
                 [133, 131, 131, ..., 137, 136, 135],
                 [131, 133, 133, ..., 136, 136, 134],
                 ...,
                 [117,  95, 108, ..., 148, 147, 140],
                 [119, 114, 119, ..., 150, 150, 135],
                 [120, 119, 119, ..., 145, 145, 134]],
 
                [[131, 134, 131, ..., 136, 133, 134],
                 [131, 135, 130, ..., 136, 135, 133],
                 [133, 136, 131, ..., 136, 134, 131],
                 ...,
                 [119, 114, 117, ..., 145, 145, 136],
                 [117, 117, 119, ..., 145, 147, 133],
                 [117, 120, 117, ..., 145, 145, 131]]], dtype=uint8),
         array([[[172, 156, 166, ..., 174, 173, 176],
                 [174, 173, 173, ..., 177, 173, 176],
                 [173, 172, 173, ..., 176, 174, 176],
                 ...,
                 [169, 171, 170, ..., 168, 169, 169],
                 [169, 169, 169, ..., 172, 168, 171],
                 [170, 169, 169, ..., 170, 170, 171]],
 
                [[174, 155, 166, ..., 173, 174, 174],
                 [174, 172, 174, ..., 173, 176, 176],
                 [174, 174, 174, ..., 173, 176, 173],
                 ...,
                 [171, 171, 172, ..., 168, 171, 169],
                 [172, 171, 172, ..., 169, 169, 169],
                 [170, 171, 169, ..., 169, 170, 169]],
 
                [[173, 154, 166, ..., 172, 177, 173],
                 [174, 173, 174, ..., 174, 174, 173],
                 [173, 174, 173, ..., 173, 177, 173],
                 ...,
                 [170, 168, 169, ..., 166, 169, 165],
                 [170, 169, 170, ..., 169, 169, 169],
                 [169, 168, 169, ..., 168, 171, 169]],
 
                ...,
 
                [[137, 136, 136, ..., 136, 136, 135],
                 [136, 136, 136, ..., 136, 136, 135],
                 [138, 136, 136, ..., 137, 136, 135],
                 ...,
                 [ 48,  50,  73, ..., 127, 131, 120],
                 [ 58,  66, 116, ..., 127, 127, 127],
                 [ 77,  85, 122, ..., 126, 127, 123]],
 
                [[136, 136, 136, ..., 135, 133, 135],
                 [136, 136, 136, ..., 136, 136, 134],
                 [136, 135, 136, ..., 136, 134, 134],
                 ...,
                 [ 62,  57, 115, ..., 127, 130, 116],
                 [ 98,  97, 123, ..., 125, 126, 127],
                 [115, 116, 126, ..., 123, 126, 122]],
 
                [[135, 135, 135, ..., 135, 134, 135],
                 [136, 136, 135, ..., 135, 135, 134],
                 [136, 136, 135, ..., 135, 135, 134],
                 ...,
                 [ 72,  85, 117, ..., 123, 122, 108],
                 [107, 120, 122, ..., 123, 125, 122],
                 [117, 123, 122, ..., 122, 126, 128]]], dtype=uint8),
         array([[[172, 173, 169, ..., 169, 170, 168],
                 [170, 172, 168, ..., 169, 171, 171],
                 [169, 170, 169, ..., 173, 169, 170],
                 ...,
                 [171, 170, 169, ..., 165, 163, 164],
                 [170, 170, 169, ..., 164, 164, 163],
                 [169, 169, 168, ..., 168, 164, 166]],
 
                [[171, 170, 171, ..., 168, 172, 170],
                 [169, 169, 169, ..., 169, 172, 171],
                 [169, 170, 169, ..., 169, 171, 170],
                 ...,
                 [169, 168, 171, ..., 164, 165, 166],
                 [169, 169, 169, ..., 163, 164, 165],
                 [169, 168, 169, ..., 164, 164, 169]],
 
                [[172, 171, 171, ..., 172, 168, 168],
                 [172, 172, 168, ..., 172, 169, 170],
                 [171, 171, 168, ..., 172, 169, 169],
                 ...,
                 [169, 169, 169, ..., 164, 166, 166],
                 [170, 169, 170, ..., 164, 166, 165],
                 [170, 169, 169, ..., 165, 164, 165]],
 
                ...,
 
                [[141, 135, 131, ..., 126, 126, 127],
                 [159, 143, 131, ..., 126, 126, 127],
                 [159, 162, 136, ..., 127, 126, 127],
                 ...,
                 [115, 114, 113, ..., 112, 112, 113],
                 [115, 115, 113, ..., 108, 111, 112],
                 [117, 117, 113, ..., 112, 108, 113]],
 
                [[137, 134, 133, ..., 126, 127, 125],
                 [159, 138, 131, ..., 127, 127, 127],
                 [161, 159, 136, ..., 128, 127, 127],
                 ...,
                 [115, 114, 113, ..., 113, 113, 113],
                 [113, 115, 113, ..., 113, 113, 108],
                 [114, 115, 113, ..., 112, 113, 112]],
 
                [[136, 133, 131, ..., 126, 127, 125],
                 [159, 137, 131, ..., 128, 127, 127],
                 [163, 158, 136, ..., 127, 127, 127],
                 ...,
                 [114, 113, 115, ..., 113, 112, 112],
                 [116, 114, 113, ..., 112, 113, 108],
                 [115, 113, 113, ..., 113, 113, 109]]], dtype=uint8),
         array([[[164, 164, 164, ...,  39,  49,  15],
                 [164, 164, 164, ...,  43,  49,  28],
                 [165, 164, 164, ...,  43,  45,  17],
                 ...,
                 [166, 164, 154, ..., 115, 113,  15],
                 [165, 164, 163, ..., 144, 141,  20],
                 [165, 164, 164, ..., 161, 159,  20]],
 
                [[164, 163, 163, ...,  38,  46,  15],
                 [164, 163, 163, ...,  39,  45,  28],
                 [164, 164, 161, ...,  43,  45,  24],
                 ...,
                 [164, 159, 151, ..., 115, 111,  20],
                 [165, 163, 159, ..., 144, 141,  21],
                 [164, 161, 162, ..., 159, 158,  18]],
 
                [[164, 164, 164, ...,  44,  43,  11],
                 [164, 164, 165, ...,  42,  43,  18],
                 [164, 163, 164, ...,  41,  42,  29],
                 ...,
                 [164, 157, 128, ..., 130, 108,  18],
                 [164, 164, 149, ..., 151, 137,  20],
                 [166, 164, 155, ..., 158, 157,  21]],
 
                ...,
 
                [[154, 150, 150, ..., 155, 155,  93],
                 [152, 150, 150, ..., 155, 154,  93],
                 [151, 149, 149, ..., 155, 152,  93],
                 ...,
                 [115, 117, 117, ..., 113, 113, 101],
                 [115, 116, 117, ..., 115, 113,  91],
                 [116, 115, 114, ..., 113, 113,  89]],
 
                [[152, 149, 149, ..., 155, 152,  93],
                 [150, 149, 150, ..., 155, 152,  92],
                 [150, 147, 147, ..., 154, 152,  92],
                 ...,
                 [112, 116, 116, ..., 113, 114, 104],
                 [113, 116, 115, ..., 113, 113,  85],
                 [115, 113, 114, ..., 113, 114,  81]],
 
                [[150, 148, 149, ..., 155, 155,  93],
                 [150, 147, 149, ..., 155, 154,  92],
                 [150, 145, 147, ..., 155, 154,  94],
                 ...,
                 [113, 114, 116, ..., 113, 115, 106],
                 [113, 116, 116, ..., 114, 114,  87],
                 [114, 115, 115, ..., 113, 115,  86]]], dtype=uint8)]],
       dtype=object),
 'dirnames': array([[array(['1a'], dtype='<U2'), array(['1b'], dtype='<U2'),
         array(['1c'], dtype='<U2'), array(['1d'], dtype='<U2'),
         array(['1e'], dtype='<U2'), array(['1f'], dtype='<U2'),
         array(['1g'], dtype='<U2'), array(['1h'], dtype='<U2'),
         array(['1i'], dtype='<U2'), array(['1j'], dtype='<U2'),
         array(['1k'], dtype='<U2'), array(['1l'], dtype='<U2'),
         array(['1m'], dtype='<U2'), array(['1n'], dtype='<U2'),
         array(['1o'], dtype='<U2'), array(['1p'], dtype='<U2'),
         array(['1q'], dtype='<U2'), array(['1r'], dtype='<U2'),
         array(['1s'], dtype='<U2'), array(['1t'], dtype='<U2')]],
       dtype=object)}
In [403]:
data, label = [], []
for idx, person in enumerate(face_data):
  for i in range(person.shape[2]):
    img = person[:,:,i]
    img = img.reshape(10304)
    data.append(img)
    label.append(idx)

data, label = np.array(data), np.array(label)   # LABEL: Range of label values: 0-19. When using the for-loop and print(f'Label: {label}'), each picture is tagged with the correct label per-person
print(f'data shape: {data.shape}, label shape: {label.shape}, unique people ids: {np.unique(label)}')
data shape: (575, 10304), label shape: (575,), unique people ids: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
In [404]:
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(10, 10))
for i in range(0,20):
    fig.add_subplot(4,5,i+1)
    plt.imshow(data[i].reshape(112,92), cmap='gray')
    plt.xticks([])
    plt.yticks([])
    person_id_label = f'Person ID: {label[i]}' 
    plt.title(person_id_label).set_color('green')
No description has been provided for this image
In [405]:
##
In [406]:
x_train, x_test, y_train, y_test = train_test_split(data, label, test_size=0.2, random_state=random_state, stratify=label)
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.25, random_state=random_state, stratify=y_train) 

print(f'x_train: {x_train.shape}\ny_train: {y_train.shape}\nx_test: {x_test.shape}\ny_test: {y_test.shape}\nx_val: {x_val.shape}\ny_val: {y_val.shape}')
x_train: (345, 10304)
y_train: (345,)
x_test: (115, 10304)
y_test: (115,)
x_val: (115, 10304)
y_val: (115,)
In [407]:
# Standard scaling
scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test)
x_val = scaler.transform(x_val)

y_train = torch.eye(20)[torch.LongTensor(y_train)]
y_test = torch.eye(20)[torch.LongTensor(y_test)]
y_val = torch.eye(20)[torch.LongTensor(y_val)]


print(f'x_train: {x_train.shape}\ny_train: {y_train.shape}\nx_test: {x_test.shape}\ny_test: {y_test.shape}\nx_val: {x_val.shape}\ny_val: {y_val.shape}')
x_train: (345, 10304)
y_train: torch.Size([345, 20])
x_test: (115, 10304)
y_test: torch.Size([115, 20])
x_val: (115, 10304)
y_val: torch.Size([115, 20])
In [408]:
pca = PCA(n_components=10, random_state=random_state)
x_train_pca = pca.fit_transform(x_train)
print(f'Explained variance: {sum(pca.explained_variance_ratio_)}')
x_test_pca = pca.transform(x_test)
x_val_pca = pca.transform(x_val)
Explained variance: 0.6492373520222806
In [409]:
print(f'x_train_pca: {x_train_pca.shape}\nx_test_pca: {x_test_pca.shape}\nx_val_pca: {x_val_pca.shape}')
x_train_pca: (345, 10)
x_test_pca: (115, 10)
x_val_pca: (115, 10)
In [410]:
face_recovered = pca.inverse_transform(x_train_pca)   # Face images in the stratified training set, 0 to 344
plt.gray()
plt.xticks([])
plt.yticks([])
plt.title('Sample Image after PCA').set_color('green')
plt.imshow(np.array(face_recovered[111]).reshape(112,92))
Out[410]:
<matplotlib.image.AxesImage at 0x277c53853a0>
No description has been provided for this image
In [411]:
kmeans = KMeans(n_clusters=20, random_state=random_state)
x_train_kmeans = kmeans.fit_transform(x_train_pca)
x_test_kmeans = kmeans.transform(x_test_pca)
x_val_kmeans = kmeans.transform(x_val_pca)

x_train_tensor = torch.FloatTensor(x_train_kmeans)
x_test_tensor = torch.FloatTensor(x_test_kmeans)
x_val_tensor = torch.FloatTensor(x_val_kmeans)
D:\anaconda3\Lib\site-packages\sklearn\cluster\_kmeans.py:1429: UserWarning: KMeans is known to have a memory leak on Windows with MKL, when there are less chunks than available threads. You can avoid it by setting the environment variable OMP_NUM_THREADS=2.
  warnings.warn(
In [426]:
x_train_kmeans
Out[426]:
array([[ 85.34897988,  75.84730645,  76.09119472, ...,  97.77866986,
         83.10348576,  51.85576079],
       [ 77.5241804 , 131.31719569, 105.1927116 , ..., 108.86916624,
         93.96202626, 121.07083419],
       [ 98.44347613, 104.5782617 ,  84.55857067, ...,  98.99388958,
         75.12079416,  88.7045111 ],
       ...,
       [118.08961643,  35.20255626,  94.85801187, ...,  77.82242491,
        120.16538824,  94.55763113],
       [105.803014  ,  92.04468724,  91.22768778, ..., 105.86139586,
         83.12872025,  41.19215493],
       [ 91.05638838,  99.60399459,  76.91126016, ...,  99.80034206,
         87.99432228,  91.87189211]])
In [412]:
print(f'x_train_kmeans: {x_train_kmeans.shape}\nx_test_kmeans: {x_test_kmeans.shape}\nx_val_kmeans: {x_val_kmeans.shape}')
x_train_kmeans: (345, 20)
x_test_kmeans: (115, 20)
x_val_kmeans: (115, 20)
In [413]:
print(f'x_train_kmeans: {x_train_tensor.shape}\nx_test_kmeans: {x_train_tensor.shape}\nx_val_kmeans: {x_train_tensor.shape}')
x_train_kmeans: torch.Size([345, 20])
x_test_kmeans: torch.Size([345, 20])
x_val_kmeans: torch.Size([345, 20])
In [414]:
print(f'x_train_reshaped: {x_train_reshaped.shape}\ny_train_reshaped: {y_train_reshaped.shape}\nx_test_reshaped: {x_test_reshaped.shape}\ny_test_reshaped: {y_test_reshaped.shape}\nx_val_reshaped: {x_val_reshaped.shape}\ny_val_reshaped: {y_val_reshaped.shape}')
x_train_reshaped: (345, 1, 20)
y_train_reshaped: torch.Size([345, 1, 20])
x_test_reshaped: (115, 1, 20)
y_test_reshaped: torch.Size([115, 1, 20])
x_val_reshaped: (115, 1, 20)
y_val_reshaped: torch.Size([115, 1, 20])
In [415]:
class FaceRecognitionModel(nn.Module):
    def __init__(self, input_size, num_classes, units, dropout=False):
        super(FaceRecognitionModel, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, 34),
            nn.Linear(34, units),
            nn.ReLU(),
            nn.Dropout(0.2) if dropout else nn.Identity(),
            nn.Linear(units, units),
            nn.ReLU(),
            nn.Dropout(0.2) if dropout else nn.Identity(),
            nn.Linear(units, num_classes),
            nn.Softmax(dim=1)
        )
    
    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input
        return self.layers(x)
In [416]:
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=100):
    best_val_accuracy = 0
    best_model = None

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        train_correct = 0
        total_train = 0

        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            _, true_labels = torch.max(labels, 1)
            total_train += labels.size(0)
            train_correct += (predicted == true_labels).sum().item()

        # Validation
        model.eval()
        val_loss = 0
        val_correct = 0
        total_val = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                _, true_labels = torch.max(labels, 1)
                total_val += labels.size(0)
                val_correct += (predicted == true_labels).sum().item()

        train_accuracy = 100 * train_correct / total_train
        val_accuracy = 100 * val_correct / total_val

        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_model = model.state_dict()

        if epoch % 10 == 0:
            print(f'Epoch {epoch}: Train Loss {train_loss/len(train_loader):.4f}, '
                  f'Train Acc {train_accuracy:.2f}%, '
                  f'Val Loss {val_loss/len(val_loader):.4f}, '
                  f'Val Acc {val_accuracy:.2f}%')

    return best_model
In [417]:
def visualize_predictions(model, x_data, y_data, x_original, title):
    model.eval()
    plt.figure(figsize=(15, 10))
    
    # Predict for first 20 samples
    predictions = []
    with torch.no_grad():
        for i in range(20):
            inputs = x_data[i].unsqueeze(0)
            output = model(inputs)
            pred = torch.argmax(output).item()
            true_label = torch.argmax(y_data[i]).item()
            predictions.append(pred)

            plt.subplot(4, 5, i+1)
            # Reconstruct and show original image
            img = x_original[i].reshape(112, 92)
            plt.imshow(img, cmap='gray')
            plt.title(f'Pred: {pred}, Actual: {true_lab el}')
            plt.axis('off')
    
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()
    
    return predictions
In [418]:
print(f'x_train_kmeans: {x_train_kmeans.shape}\nx_test_kmeans: {x_test_kmeans.shape}\nx_val_kmeans: {x_val_kmeans.shape}')
x_train_kmeans: (345, 20)
x_test_kmeans: (115, 20)
x_val_kmeans: (115, 20)
In [419]:
print(x_train.shape[1])
10304
In [420]:
print(x_train_tensor.shape)
torch.Size([345, 20])
In [421]:
# Create datasets and loaders
train_dataset = TensorDataset(x_train_tensor, y_train)
val_dataset = TensorDataset(x_val_tensor, y_val)
test_dataset = TensorDataset(x_test_tensor, y_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

# Hyperparameters
model_params = {
    'input_size': x_train_tensor.shape[1],
    'num_classes': 20,
    'units': 128,
    'dropout': True
}

# Initialize model
model = FaceRecognitionModel(**model_params)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Train the model
best_model_state = train_model(model, train_loader, val_loader, criterion, optimizer)

# Load best model
model.load_state_dict(best_model_state)

# Evaluate on test set
model.eval()
test_correct = 0
total_test = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        _, true_labels = torch.max(labels, 1)
        total_test += labels.size(0)
        test_correct += (predicted == true_labels).sum().item()

test_accuracy = 100 * test_correct / total_test
print(f'Test Accuracy: {test_accuracy:.2f}%')
Epoch 0: Train Loss 0.0802, Train Acc 6.96%, Val Loss 0.0650, Val Acc 11.30%
Epoch 10: Train Loss 0.0399, Train Acc 32.75%, Val Loss 0.0371, Val Acc 37.39%
Epoch 20: Train Loss 0.0255, Train Acc 62.61%, Val Loss 0.0229, Val Acc 66.96%
Epoch 30: Train Loss 0.0168, Train Acc 76.52%, Val Loss 0.0128, Val Acc 83.48%
Epoch 40: Train Loss 0.0115, Train Acc 84.06%, Val Loss 0.0080, Val Acc 90.43%
Epoch 50: Train Loss 0.0102, Train Acc 84.93%, Val Loss 0.0075, Val Acc 90.43%
Epoch 60: Train Loss 0.0065, Train Acc 91.30%, Val Loss 0.0041, Val Acc 94.78%
Epoch 70: Train Loss 0.0060, Train Acc 92.75%, Val Loss 0.0037, Val Acc 94.78%
Epoch 80: Train Loss 0.0055, Train Acc 91.59%, Val Loss 0.0043, Val Acc 93.91%
Epoch 90: Train Loss 0.0040, Train Acc 95.07%, Val Loss 0.0022, Val Acc 97.39%
Test Accuracy: 94.78%
In [422]:
# Visualize predictions
print("Visualizing Training Set Predictions")
train_predictions = visualize_predictions(model, x_train_tensor, y_train, x_train_orig, 'Training Set Predictions')
Visualizing Training Set Predictions
No description has been provided for this image
In [423]:
print("Visualizing Test Set Predictions")
test_predictions = visualize_predictions(model, x_test_tensor, y_test, x_test_orig, 'Test Set Predictions')
Visualizing Test Set Predictions
No description has been provided for this image
In [424]:
print("Visualizing Validation Set Predictions")
val_predictions = visualize_predictions(model, x_val_tensor, y_val, x_val_orig, 'Validation Set Predictions')
Visualizing Validation Set Predictions
No description has been provided for this image
In [425]:
torch.save(model.state_dict(), 'face_recognition_model.pth')
In [ ]: