In [399]:
##Loading
In [400]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
random_state=42
np.random.seed(random_state)
torch.manual_seed(random_state)
Out[400]:
<torch._C.Generator at 0x277da6e4e30>
In [401]:
mat_dict = sio.loadmat(file_name='umist_cropped.mat', appendmat=False) # Load MATLAB file into a Python dictionary REF (scipy.io.loadmat): https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.loadmat.html
print(f'mat_dict type: {type(mat_dict)}, mat_dict.keys(): {mat_dict.keys()}')
print(f'\nfacedat type: {type(mat_dict["facedat"])}, facedat shape: {mat_dict["facedat"].shape}')
print(f'dirnames type: {type(mat_dict["dirnames"])}, dirnames shape: {mat_dict["dirnames"].shape}')
face_data = mat_dict['facedat'][0] # Load the face data from the dictionary. Produces numpy.ndarray
print(f'\nNumber of people in images [len(face_data)]: {len(face_data)}\nImage height [len(face_data[0]]: {len(face_data[0])}\nImage width [len(face_data[0][0]]: {len(face_data[0][0])}\nImage total pixels [height X width]: {len(face_data[0]) * len(face_data[0][0])}')
mat_dict type: <class 'dict'>, mat_dict.keys(): dict_keys(['__header__', '__version__', '__globals__', 'facedat', 'dirnames']) facedat type: <class 'numpy.ndarray'>, facedat shape: (1, 20) dirnames type: <class 'numpy.ndarray'>, dirnames shape: (1, 20) Number of people in images [len(face_data)]: 20 Image height [len(face_data[0]]: 112 Image width [len(face_data[0][0]]: 92 Image total pixels [height X width]: 10304
In [402]:
mat_dict
Out[402]:
{'__header__': b'MATLAB 5.0 MAT-file, Platform: GLNX86, Created on: Wed Aug 28 11:38:19 2002',
'__version__': '1.0',
'__globals__': [],
'facedat': array([[array([[[233, 234, 234, ..., 236, 230, 234],
[234, 234, 234, ..., 235, 234, 232],
[234, 234, 234, ..., 236, 233, 234],
...,
[234, 234, 234, ..., 237, 232, 233],
[234, 234, 234, ..., 236, 233, 234],
[234, 234, 233, ..., 236, 234, 233]],
[[234, 234, 234, ..., 236, 232, 233],
[234, 234, 234, ..., 234, 233, 234],
[234, 234, 234, ..., 235, 233, 233],
...,
[234, 234, 234, ..., 236, 230, 234],
[234, 234, 234, ..., 236, 234, 233],
[234, 234, 234, ..., 237, 233, 230]],
[[234, 233, 234, ..., 236, 232, 233],
[234, 234, 233, ..., 235, 230, 234],
[234, 234, 234, ..., 237, 233, 233],
...,
[234, 234, 234, ..., 237, 232, 233],
[234, 233, 234, ..., 239, 232, 233],
[234, 232, 234, ..., 236, 230, 232]],
...,
[[166, 190, 206, ..., 219, 215, 187],
[177, 202, 208, ..., 216, 215, 186],
[194, 204, 205, ..., 213, 218, 187],
...,
[147, 145, 161, ..., 142, 187, 211],
[144, 142, 157, ..., 176, 187, 214],
[141, 142, 155, ..., 187, 187, 215]],
[[ 92, 179, 206, ..., 216, 219, 187],
[113, 197, 206, ..., 215, 220, 187],
[143, 199, 201, ..., 215, 220, 190],
...,
[141, 142, 157, ..., 186, 188, 215],
[142, 138, 152, ..., 190, 187, 216],
[140, 140, 154, ..., 192, 187, 216]],
[[ 65, 104, 204, ..., 215, 218, 184],
[100, 135, 206, ..., 214, 220, 184],
[123, 159, 194, ..., 215, 220, 188],
...,
[141, 138, 148, ..., 192, 187, 215],
[140, 137, 145, ..., 192, 187, 215],
[136, 137, 145, ..., 193, 185, 215]]], dtype=uint8),
array([[[232, 233, 232, ..., 234, 234, 221],
[229, 232, 232, ..., 234, 234, 209],
[230, 230, 232, ..., 234, 234, 140],
...,
[230, 229, 230, ..., 42, 43, 66],
[229, 229, 230, ..., 64, 64, 128],
[229, 230, 230, ..., 91, 105, 201]],
[[229, 232, 232, ..., 233, 234, 207],
[234, 230, 230, ..., 234, 234, 174],
[229, 230, 229, ..., 233, 234, 102],
...,
[229, 232, 232, ..., 41, 39, 57],
[229, 230, 229, ..., 59, 51, 86],
[229, 230, 229, ..., 87, 84, 165]],
[[232, 230, 230, ..., 234, 234, 165],
[230, 229, 230, ..., 234, 233, 108],
[229, 230, 229, ..., 234, 234, 93],
...,
[229, 229, 230, ..., 38, 38, 56],
[229, 230, 229, ..., 45, 44, 63],
[229, 230, 229, ..., 70, 57, 104]],
...,
[[ 83, 85, 97, ..., 200, 200, 200],
[ 87, 92, 93, ..., 200, 200, 201],
[ 92, 97, 92, ..., 201, 197, 197],
...,
[122, 122, 122, ..., 127, 138, 162],
[121, 123, 120, ..., 126, 130, 152],
[123, 123, 122, ..., 125, 126, 143]],
[[ 80, 84, 100, ..., 198, 199, 197],
[ 85, 92, 99, ..., 200, 198, 197],
[ 85, 92, 95, ..., 200, 197, 195],
...,
[122, 122, 125, ..., 125, 130, 158],
[122, 122, 122, ..., 126, 128, 149],
[122, 121, 123, ..., 125, 125, 138]],
[[ 83, 78, 99, ..., 161, 193, 197],
[ 84, 84, 101, ..., 182, 197, 197],
[ 77, 86, 105, ..., 198, 197, 195],
...,
[121, 122, 123, ..., 123, 127, 148],
[122, 122, 123, ..., 126, 125, 140],
[122, 122, 123, ..., 125, 123, 131]]], dtype=uint8),
array([[[233, 230, 234, ..., 234, 234, 234],
[232, 209, 232, ..., 234, 234, 234],
[230, 113, 195, ..., 234, 234, 234],
...,
[230, 232, 232, ..., 234, 234, 234],
[232, 232, 233, ..., 234, 234, 234],
[233, 232, 230, ..., 234, 234, 234]],
[[233, 229, 230, ..., 234, 235, 234],
[230, 197, 229, ..., 234, 234, 234],
[232, 94, 194, ..., 234, 234, 234],
...,
[230, 232, 232, ..., 234, 234, 234],
[232, 229, 230, ..., 234, 234, 234],
[230, 229, 230, ..., 234, 234, 234]],
[[233, 178, 229, ..., 234, 234, 234],
[232, 143, 208, ..., 234, 234, 234],
[233, 57, 135, ..., 234, 234, 234],
...,
[230, 232, 233, ..., 234, 233, 234],
[232, 230, 229, ..., 234, 234, 234],
[233, 229, 230, ..., 233, 234, 234]],
...,
[[170, 194, 194, ..., 205, 201, 201],
[173, 198, 199, ..., 205, 202, 201],
[192, 193, 197, ..., 201, 201, 200],
...,
[123, 114, 71, ..., 135, 127, 147],
[123, 123, 109, ..., 133, 127, 143],
[123, 126, 123, ..., 128, 126, 135]],
[[168, 190, 190, ..., 201, 205, 201],
[170, 193, 197, ..., 204, 204, 201],
[190, 193, 193, ..., 200, 201, 201],
...,
[122, 119, 78, ..., 134, 127, 141],
[122, 122, 112, ..., 129, 127, 138],
[122, 123, 123, ..., 126, 127, 131]],
[[166, 188, 190, ..., 202, 201, 202],
[170, 197, 195, ..., 201, 201, 201],
[186, 192, 194, ..., 200, 201, 201],
...,
[122, 121, 115, ..., 128, 126, 133],
[120, 121, 122, ..., 126, 126, 130],
[121, 122, 123, ..., 122, 126, 127]]], dtype=uint8),
array([[[141, 234, 230, ..., 225, 226, 225],
[116, 232, 233, ..., 225, 228, 226],
[104, 234, 232, ..., 226, 230, 228],
...,
[230, 233, 232, ..., 234, 234, 234],
[232, 232, 232, ..., 234, 234, 234],
[232, 233, 232, ..., 234, 234, 234]],
[[129, 233, 233, ..., 225, 225, 222],
[111, 233, 233, ..., 225, 227, 223],
[100, 234, 233, ..., 229, 229, 225],
...,
[229, 233, 232, ..., 234, 234, 234],
[233, 232, 230, ..., 234, 234, 234],
[229, 229, 232, ..., 234, 234, 234]],
[[116, 232, 233, ..., 226, 225, 222],
[111, 233, 234, ..., 225, 225, 225],
[ 70, 232, 233, ..., 229, 229, 225],
...,
[232, 233, 232, ..., 234, 234, 233],
[230, 230, 232, ..., 234, 234, 234],
[232, 230, 232, ..., 234, 233, 232]],
...,
[[194, 197, 201, ..., 173, 171, 171],
[190, 197, 201, ..., 171, 171, 172],
[ 90, 197, 199, ..., 173, 173, 172],
...,
[129, 126, 128, ..., 183, 180, 180],
[129, 127, 127, ..., 180, 179, 179],
[128, 126, 127, ..., 183, 179, 180]],
[[195, 199, 200, ..., 174, 173, 171],
[192, 199, 200, ..., 173, 173, 172],
[138, 199, 197, ..., 174, 174, 171],
...,
[128, 127, 126, ..., 184, 182, 178],
[129, 127, 127, ..., 183, 180, 178],
[130, 127, 127, ..., 183, 183, 178]],
[[194, 200, 200, ..., 173, 173, 170],
[192, 200, 198, ..., 173, 172, 172],
[147, 197, 197, ..., 173, 174, 170],
...,
[127, 127, 126, ..., 182, 182, 178],
[127, 127, 127, ..., 182, 182, 177],
[127, 128, 125, ..., 183, 183, 171]]], dtype=uint8),
array([[[193, 193, 192, ..., 197, 193, 194],
[192, 193, 192, ..., 197, 192, 194],
[192, 193, 190, ..., 197, 193, 192],
...,
[193, 192, 192, ..., 195, 192, 191],
[193, 192, 192, ..., 194, 193, 192],
[192, 193, 192, ..., 194, 193, 192]],
[[192, 192, 193, ..., 197, 197, 195],
[191, 192, 191, ..., 197, 194, 195],
[191, 192, 192, ..., 197, 194, 193],
...,
[191, 191, 192, ..., 192, 194, 192],
[192, 192, 192, ..., 192, 192, 193],
[192, 192, 194, ..., 192, 194, 193]],
[[192, 192, 192, ..., 195, 197, 195],
[191, 192, 192, ..., 194, 193, 195],
[191, 191, 191, ..., 192, 193, 193],
...,
[192, 192, 192, ..., 192, 194, 193],
[192, 192, 192, ..., 192, 193, 192],
[191, 192, 192, ..., 192, 195, 192]],
...,
[[145, 156, 148, ..., 147, 169, 169],
[166, 172, 169, ..., 148, 174, 173],
[168, 165, 168, ..., 165, 173, 171],
...,
[123, 122, 125, ..., 155, 155, 154],
[125, 123, 127, ..., 155, 155, 155],
[122, 122, 127, ..., 155, 155, 155]],
[[143, 155, 145, ..., 144, 164, 165],
[164, 169, 171, ..., 148, 173, 173],
[166, 168, 170, ..., 164, 172, 171],
...,
[122, 121, 125, ..., 155, 155, 155],
[123, 122, 127, ..., 155, 155, 152],
[122, 122, 126, ..., 155, 154, 154]],
[[144, 154, 145, ..., 144, 164, 164],
[163, 169, 168, ..., 145, 173, 171],
[168, 166, 170, ..., 161, 171, 170],
...,
[121, 120, 123, ..., 155, 155, 154],
[122, 122, 123, ..., 154, 155, 152],
[122, 123, 125, ..., 150, 154, 152]]], dtype=uint8),
array([[[173, 173, 173, ..., 174, 174, 173],
[173, 176, 171, ..., 176, 174, 172],
[173, 173, 170, ..., 176, 173, 173],
...,
[177, 179, 178, ..., 178, 179, 179],
[178, 179, 177, ..., 179, 178, 178],
[178, 178, 176, ..., 178, 179, 178]],
[[173, 173, 174, ..., 173, 173, 173],
[172, 173, 174, ..., 173, 173, 171],
[172, 173, 173, ..., 173, 174, 173],
...,
[176, 178, 179, ..., 178, 178, 177],
[177, 178, 178, ..., 178, 178, 177],
[174, 177, 178, ..., 178, 177, 177]],
[[170, 173, 173, ..., 174, 173, 173],
[171, 173, 173, ..., 173, 173, 174],
[171, 171, 173, ..., 173, 173, 170],
...,
[176, 176, 178, ..., 177, 178, 176],
[176, 178, 178, ..., 178, 178, 178],
[176, 177, 178, ..., 177, 178, 178]],
...,
[[ 29, 34, 35, ..., 87, 86, 147],
[ 29, 34, 34, ..., 67, 66, 115],
[ 29, 30, 30, ..., 43, 43, 63],
...,
[119, 119, 117, ..., 119, 119, 123],
[119, 121, 117, ..., 117, 116, 120],
[117, 119, 119, ..., 117, 117, 119]],
[[ 30, 34, 35, ..., 70, 66, 120],
[ 29, 31, 33, ..., 55, 53, 79],
[ 29, 29, 29, ..., 37, 36, 52],
...,
[117, 120, 117, ..., 117, 117, 120],
[117, 120, 117, ..., 117, 116, 119],
[119, 121, 119, ..., 116, 115, 117]],
[[ 36, 30, 30, ..., 55, 52, 81],
[ 36, 29, 29, ..., 46, 45, 63],
[ 38, 30, 31, ..., 36, 36, 44],
...,
[117, 120, 116, ..., 116, 117, 116],
[117, 120, 117, ..., 117, 114, 117],
[115, 119, 117, ..., 116, 113, 117]]], dtype=uint8),
array([[[179, 177, 178, ..., 168, 145, 148],
[179, 178, 182, ..., 173, 147, 150],
[177, 178, 180, ..., 178, 147, 149],
...,
[163, 162, 164, ..., 164, 164, 164],
[163, 161, 164, ..., 164, 164, 164],
[163, 161, 164, ..., 164, 164, 165]],
[[178, 174, 177, ..., 166, 148, 147],
[178, 178, 178, ..., 176, 150, 147],
[173, 180, 178, ..., 178, 150, 147],
...,
[162, 163, 162, ..., 164, 164, 163],
[161, 164, 161, ..., 164, 164, 163],
[162, 163, 162, ..., 164, 164, 164]],
[[178, 174, 177, ..., 163, 145, 148],
[178, 178, 178, ..., 172, 148, 148],
[177, 180, 179, ..., 177, 147, 148],
...,
[159, 163, 162, ..., 164, 164, 163],
[161, 163, 163, ..., 164, 164, 163],
[159, 163, 161, ..., 164, 164, 164]],
...,
[[ 43, 43, 31, ..., 69, 36, 42],
[ 43, 43, 28, ..., 55, 31, 46],
[ 37, 43, 35, ..., 23, 24, 45],
...,
[113, 113, 113, ..., 115, 114, 113],
[114, 112, 114, ..., 114, 115, 114],
[114, 113, 113, ..., 114, 113, 113]],
[[ 43, 43, 30, ..., 72, 34, 37],
[ 43, 43, 34, ..., 60, 31, 44],
[ 37, 43, 42, ..., 24, 24, 44],
...,
[113, 113, 113, ..., 113, 116, 113],
[113, 113, 113, ..., 113, 114, 113],
[113, 114, 113, ..., 113, 114, 113]],
[[ 43, 49, 31, ..., 92, 35, 29],
[ 43, 44, 31, ..., 77, 30, 37],
[ 41, 43, 38, ..., 38, 23, 43],
...,
[112, 113, 113, ..., 113, 116, 114],
[113, 113, 113, ..., 113, 115, 115],
[113, 114, 113, ..., 113, 113, 114]]], dtype=uint8),
array([[[169, 169, 170, ..., 169, 170, 169],
[168, 169, 170, ..., 170, 169, 169],
[168, 169, 169, ..., 170, 169, 165],
...,
[158, 157, 161, ..., 157, 158, 159],
[158, 158, 158, ..., 157, 159, 159],
[157, 159, 159, ..., 159, 157, 159]],
[[170, 172, 168, ..., 170, 171, 164],
[172, 172, 169, ..., 169, 170, 165],
[168, 170, 168, ..., 170, 172, 165],
...,
[159, 159, 158, ..., 159, 159, 156],
[159, 161, 157, ..., 159, 159, 158],
[159, 162, 158, ..., 159, 161, 158]],
[[169, 169, 169, ..., 169, 170, 166],
[169, 169, 169, ..., 169, 169, 165],
[166, 168, 169, ..., 169, 168, 165],
...,
[157, 159, 159, ..., 157, 158, 159],
[157, 161, 159, ..., 159, 159, 158],
[158, 159, 161, ..., 159, 159, 159]],
...,
[[ 41, 37, 53, ..., 163, 157, 159],
[ 39, 38, 62, ..., 157, 151, 164],
[ 34, 36, 57, ..., 151, 150, 164],
...,
[ 63, 72, 76, ..., 127, 97, 136],
[ 91, 102, 105, ..., 122, 97, 136],
[113, 113, 113, ..., 121, 101, 136]],
[[ 39, 44, 37, ..., 161, 163, 157],
[ 37, 43, 41, ..., 159, 161, 163],
[ 34, 37, 53, ..., 159, 158, 163],
...,
[ 66, 90, 98, ..., 127, 129, 136],
[ 94, 113, 113, ..., 119, 128, 135],
[112, 113, 113, ..., 117, 127, 135]],
[[ 44, 44, 34, ..., 161, 161, 156],
[ 43, 43, 34, ..., 161, 159, 159],
[ 38, 38, 41, ..., 159, 158, 159],
...,
[ 80, 94, 112, ..., 123, 129, 134],
[109, 111, 113, ..., 117, 127, 131],
[113, 113, 112, ..., 116, 125, 134]]], dtype=uint8),
array([[[173, 187, 183, ..., 186, 187, 190],
[161, 186, 183, ..., 184, 184, 187],
[141, 183, 183, ..., 179, 180, 186],
...,
[174, 173, 173, ..., 147, 178, 177],
[176, 173, 173, ..., 156, 178, 174],
[176, 173, 173, ..., 162, 178, 177]],
[[166, 186, 185, ..., 184, 182, 190],
[156, 186, 184, ..., 180, 182, 191],
[134, 184, 184, ..., 177, 178, 187],
...,
[173, 173, 174, ..., 142, 178, 178],
[172, 173, 177, ..., 150, 176, 178],
[173, 174, 176, ..., 156, 177, 178]],
[[152, 185, 183, ..., 179, 182, 188],
[140, 183, 183, ..., 174, 177, 187],
[122, 184, 182, ..., 166, 172, 188],
...,
[174, 173, 173, ..., 142, 178, 177],
[173, 173, 176, ..., 149, 177, 178],
[174, 173, 173, ..., 150, 178, 176]],
...,
[[ 39, 45, 38, ..., 39, 45, 38],
[ 42, 45, 38, ..., 37, 46, 38],
[ 37, 46, 37, ..., 36, 48, 35],
...,
[114, 114, 113, ..., 33, 76, 77],
[113, 114, 113, ..., 43, 97, 84],
[114, 116, 113, ..., 56, 104, 86]],
[[ 38, 44, 38, ..., 37, 44, 38],
[ 42, 45, 35, ..., 34, 46, 38],
[ 35, 45, 38, ..., 33, 48, 34],
...,
[115, 113, 113, ..., 35, 99, 76],
[114, 113, 114, ..., 57, 107, 84],
[116, 113, 114, ..., 76, 113, 86]],
[[ 39, 42, 34, ..., 36, 37, 37],
[ 42, 44, 34, ..., 34, 41, 37],
[ 34, 45, 37, ..., 33, 43, 34],
...,
[113, 113, 113, ..., 34, 101, 85],
[113, 113, 113, ..., 58, 107, 88],
[114, 114, 113, ..., 78, 112, 92]]], dtype=uint8),
array([[[214, 214, 212, ..., 215, 206, 218],
[213, 213, 211, ..., 218, 207, 215],
[212, 213, 214, ..., 215, 204, 215],
...,
[205, 204, 201, ..., 202, 215, 205],
[204, 205, 201, ..., 204, 219, 205],
[202, 204, 204, ..., 204, 218, 206]],
[[215, 215, 214, ..., 216, 204, 215],
[215, 214, 214, ..., 216, 205, 214],
[214, 214, 214, ..., 215, 201, 215],
...,
[206, 206, 205, ..., 206, 214, 204],
[206, 206, 206, ..., 206, 215, 206],
[206, 206, 206, ..., 206, 215, 201]],
[[215, 215, 214, ..., 215, 205, 215],
[215, 215, 215, ..., 215, 204, 214],
[214, 214, 213, ..., 215, 201, 214],
...,
[206, 206, 205, ..., 206, 215, 202],
[206, 206, 205, ..., 206, 216, 201],
[206, 206, 205, ..., 206, 215, 206]],
...,
[[200, 201, 199, ..., 173, 131, 173],
[200, 201, 198, ..., 173, 130, 173],
[199, 201, 198, ..., 174, 131, 174],
...,
[130, 131, 131, ..., 128, 174, 129],
[131, 131, 131, ..., 128, 176, 128],
[130, 131, 131, ..., 128, 172, 128]],
[[201, 201, 199, ..., 172, 129, 172],
[201, 200, 198, ..., 171, 130, 173],
[201, 198, 194, ..., 172, 130, 173],
...,
[131, 131, 130, ..., 128, 173, 129],
[131, 131, 129, ..., 128, 173, 129],
[131, 131, 129, ..., 128, 173, 128]],
[[201, 201, 197, ..., 170, 129, 172],
[201, 201, 197, ..., 173, 128, 172],
[199, 198, 193, ..., 171, 130, 173],
...,
[131, 131, 131, ..., 127, 172, 129],
[133, 130, 127, ..., 129, 174, 129],
[131, 131, 129, ..., 128, 171, 128]]], dtype=uint8),
array([[[207, 206, 211, ..., 218, 220, 219],
[207, 207, 211, ..., 219, 220, 218],
[207, 206, 208, ..., 216, 220, 219],
...,
[206, 202, 206, ..., 208, 208, 209],
[206, 204, 206, ..., 208, 211, 207],
[205, 205, 204, ..., 208, 209, 208]],
[[207, 208, 208, ..., 220, 219, 219],
[206, 209, 208, ..., 219, 219, 220],
[206, 207, 208, ..., 220, 218, 220],
...,
[204, 205, 202, ..., 208, 207, 209],
[204, 206, 204, ..., 211, 208, 211],
[201, 205, 202, ..., 211, 208, 211]],
[[209, 209, 211, ..., 218, 220, 220],
[208, 208, 209, ..., 215, 220, 220],
[209, 206, 208, ..., 218, 218, 219],
...,
[206, 202, 206, ..., 206, 209, 211],
[206, 204, 204, ..., 207, 209, 211],
[204, 204, 206, ..., 208, 211, 211]],
...,
[[192, 195, 192, ..., 206, 208, 206],
[192, 192, 192, ..., 206, 206, 206],
[191, 192, 191, ..., 204, 206, 206],
...,
[133, 134, 131, ..., 137, 137, 136],
[134, 134, 134, ..., 136, 136, 136],
[133, 134, 133, ..., 135, 136, 136]],
[[192, 192, 193, ..., 206, 206, 206],
[193, 192, 192, ..., 205, 206, 205],
[191, 191, 192, ..., 202, 204, 206],
...,
[133, 131, 133, ..., 136, 137, 135],
[131, 133, 131, ..., 136, 136, 138],
[131, 134, 131, ..., 136, 136, 137]],
[[192, 192, 192, ..., 206, 206, 205],
[190, 191, 192, ..., 206, 204, 204],
[188, 190, 188, ..., 206, 204, 205],
...,
[131, 131, 131, ..., 136, 136, 135],
[131, 131, 131, ..., 136, 136, 136],
[131, 131, 131, ..., 136, 136, 136]]], dtype=uint8),
array([[[225, 225, 226, ..., 229, 229, 229],
[223, 225, 225, ..., 229, 229, 229],
[225, 225, 225, ..., 229, 229, 229],
...,
[214, 214, 206, ..., 215, 215, 214],
[216, 215, 215, ..., 216, 215, 214],
[216, 218, 215, ..., 215, 213, 214]],
[[225, 225, 225, ..., 229, 229, 229],
[223, 226, 225, ..., 229, 229, 229],
[221, 223, 225, ..., 229, 229, 229],
...,
[215, 207, 201, ..., 215, 216, 215],
[219, 215, 212, ..., 218, 216, 215],
[220, 219, 213, ..., 215, 216, 215]],
[[225, 225, 225, ..., 229, 229, 228],
[221, 225, 226, ..., 229, 229, 229],
[218, 222, 225, ..., 229, 229, 229],
...,
[215, 201, 176, ..., 216, 215, 216],
[218, 213, 202, ..., 219, 216, 216],
[218, 215, 206, ..., 215, 215, 219]],
...,
[[ 16, 16, 208, ..., 192, 204, 190],
[ 16, 15, 197, ..., 201, 214, 195],
[ 15, 15, 142, ..., 220, 218, 208],
...,
[141, 141, 143, ..., 141, 136, 136],
[141, 141, 143, ..., 141, 140, 136],
[141, 141, 144, ..., 140, 137, 140]],
[[ 15, 20, 215, ..., 192, 202, 187],
[ 15, 15, 212, ..., 198, 213, 193],
[ 15, 16, 173, ..., 215, 218, 204],
...,
[141, 142, 142, ..., 136, 137, 136],
[141, 142, 142, ..., 141, 141, 136],
[144, 141, 143, ..., 137, 141, 136]],
[[ 15, 19, 214, ..., 190, 201, 186],
[ 16, 15, 211, ..., 194, 213, 191],
[ 15, 15, 174, ..., 216, 220, 201],
...,
[142, 142, 140, ..., 137, 140, 136],
[141, 142, 141, ..., 140, 141, 135],
[142, 142, 141, ..., 137, 140, 136]]], dtype=uint8),
array([[[170, 168, 168, ..., 171, 170, 170],
[170, 168, 166, ..., 171, 168, 170],
[172, 170, 169, ..., 169, 170, 169],
...,
[173, 169, 168, ..., 173, 169, 173],
[173, 170, 169, ..., 172, 170, 173],
[173, 171, 169, ..., 171, 169, 171]],
[[170, 165, 169, ..., 170, 171, 169],
[169, 168, 168, ..., 171, 170, 168],
[170, 169, 169, ..., 169, 171, 169],
...,
[171, 168, 170, ..., 173, 172, 171],
[172, 166, 171, ..., 170, 171, 170],
[172, 166, 172, ..., 170, 173, 169]],
[[169, 166, 168, ..., 170, 171, 166],
[168, 166, 169, ..., 169, 169, 168],
[169, 166, 168, ..., 170, 170, 169],
...,
[171, 165, 169, ..., 169, 173, 170],
[171, 166, 168, ..., 172, 171, 169],
[170, 169, 169, ..., 170, 172, 169]],
...,
[[161, 162, 161, ..., 137, 140, 137],
[159, 161, 159, ..., 138, 138, 138],
[159, 161, 161, ..., 140, 142, 141],
...,
[117, 117, 117, ..., 117, 117, 117],
[116, 116, 114, ..., 117, 117, 117],
[117, 116, 115, ..., 116, 117, 117]],
[[159, 161, 163, ..., 137, 141, 136],
[158, 159, 161, ..., 138, 138, 137],
[158, 159, 159, ..., 140, 140, 138],
...,
[116, 117, 116, ..., 117, 117, 117],
[115, 113, 117, ..., 117, 117, 117],
[115, 115, 117, ..., 116, 119, 117]],
[[159, 159, 164, ..., 136, 140, 136],
[159, 159, 159, ..., 137, 138, 137],
[158, 159, 159, ..., 140, 141, 138],
...,
[115, 114, 116, ..., 119, 116, 117],
[114, 115, 117, ..., 117, 117, 119],
[114, 115, 116, ..., 116, 117, 117]]], dtype=uint8),
array([[[135, 136, 136, ..., 162, 164, 164],
[135, 136, 136, ..., 163, 164, 159],
[136, 136, 136, ..., 163, 164, 162],
...,
[162, 163, 161, ..., 134, 136, 136],
[161, 163, 159, ..., 136, 134, 136],
[161, 164, 163, ..., 134, 136, 136]],
[[136, 134, 136, ..., 164, 163, 163],
[136, 134, 136, ..., 164, 163, 164],
[136, 133, 136, ..., 163, 162, 164],
...,
[164, 162, 161, ..., 136, 134, 136],
[163, 162, 163, ..., 136, 133, 136],
[164, 161, 164, ..., 134, 136, 136]],
[[136, 135, 136, ..., 164, 162, 164],
[136, 134, 135, ..., 164, 161, 164],
[136, 133, 136, ..., 163, 162, 164],
...,
[164, 159, 161, ..., 133, 136, 136],
[164, 161, 159, ..., 134, 134, 135],
[164, 161, 161, ..., 134, 136, 135]],
...,
[[108, 108, 80, ..., 136, 131, 133],
[106, 108, 70, ..., 136, 133, 134],
[105, 109, 106, ..., 136, 134, 136],
...,
[113, 109, 112, ..., 65, 83, 92],
[113, 108, 112, ..., 79, 97, 101],
[113, 112, 113, ..., 87, 101, 101]],
[[106, 108, 101, ..., 135, 131, 136],
[104, 111, 95, ..., 136, 136, 136],
[105, 112, 105, ..., 134, 136, 136],
...,
[112, 112, 111, ..., 80, 98, 101],
[112, 113, 112, ..., 98, 105, 107],
[109, 112, 113, ..., 102, 108, 106]],
[[108, 108, 107, ..., 131, 129, 134],
[105, 108, 105, ..., 133, 134, 135],
[105, 109, 101, ..., 134, 135, 136],
...,
[108, 111, 113, ..., 84, 104, 106],
[109, 111, 113, ..., 99, 107, 109],
[111, 111, 114, ..., 101, 109, 107]]], dtype=uint8),
array([[[166, 127, 168, ..., 173, 171, 170],
[168, 149, 168, ..., 172, 172, 170],
[166, 128, 169, ..., 171, 172, 170],
...,
[172, 1, 173, ..., 150, 150, 145],
[171, 1, 173, ..., 150, 149, 145],
[171, 1, 173, ..., 145, 147, 143]],
[[165, 165, 164, ..., 171, 173, 172],
[164, 164, 165, ..., 171, 173, 173],
[165, 164, 166, ..., 170, 173, 173],
...,
[170, 169, 169, ..., 148, 150, 148],
[169, 169, 169, ..., 150, 150, 145],
[169, 171, 171, ..., 145, 147, 145]],
[[165, 164, 164, ..., 170, 173, 169],
[164, 164, 164, ..., 171, 173, 171],
[164, 165, 164, ..., 170, 173, 171],
...,
[169, 170, 171, ..., 148, 150, 147],
[169, 169, 170, ..., 147, 150, 145],
[169, 170, 171, ..., 147, 147, 143]],
...,
[[143, 142, 145, ..., 46, 63, 62],
[144, 147, 133, ..., 60, 66, 63],
[136, 143, 125, ..., 65, 66, 63],
...,
[ 36, 49, 36, ..., 127, 129, 129],
[ 39, 104, 38, ..., 128, 130, 130],
[ 42, 119, 38, ..., 128, 127, 128]],
[[150, 150, 136, ..., 50, 62, 62],
[151, 150, 138, ..., 59, 63, 63],
[148, 149, 130, ..., 62, 66, 62],
...,
[ 34, 74, 35, ..., 129, 127, 127],
[ 38, 116, 36, ..., 127, 131, 127],
[ 49, 121, 39, ..., 127, 127, 127]],
[[154, 152, 143, ..., 55, 59, 60],
[156, 151, 142, ..., 57, 64, 62],
[154, 155, 137, ..., 62, 66, 62],
...,
[ 37, 112, 37, ..., 129, 129, 127],
[ 44, 117, 38, ..., 127, 131, 129],
[ 80, 122, 55, ..., 126, 128, 127]]], dtype=uint8),
array([[[173, 169, 170, ..., 173, 173, 172],
[171, 170, 171, ..., 172, 171, 172],
[171, 170, 171, ..., 173, 171, 171],
...,
[173, 172, 173, ..., 161, 161, 168],
[173, 173, 173, ..., 164, 164, 168],
[174, 172, 173, ..., 166, 164, 171]],
[[171, 171, 169, ..., 173, 173, 173],
[169, 171, 169, ..., 173, 172, 173],
[171, 170, 169, ..., 173, 172, 170],
...,
[173, 173, 172, ..., 151, 158, 169],
[172, 173, 172, ..., 166, 165, 170],
[173, 173, 171, ..., 168, 164, 173]],
[[169, 169, 170, ..., 169, 172, 173],
[170, 169, 171, ..., 172, 171, 173],
[168, 170, 171, ..., 172, 170, 173],
...,
[172, 173, 173, ..., 142, 155, 169],
[172, 173, 174, ..., 161, 163, 169],
[172, 173, 173, ..., 165, 163, 172]],
...,
[[156, 159, 159, ..., 117, 117, 117],
[156, 161, 158, ..., 119, 117, 117],
[155, 159, 157, ..., 117, 117, 117],
...,
[117, 117, 117, ..., 159, 163, 168],
[117, 120, 116, ..., 164, 164, 162],
[117, 117, 117, ..., 159, 164, 150]],
[[158, 159, 158, ..., 117, 119, 117],
[156, 159, 158, ..., 117, 117, 117],
[155, 158, 158, ..., 117, 117, 117],
...,
[117, 117, 117, ..., 159, 162, 165],
[117, 117, 117, ..., 161, 164, 157],
[117, 119, 117, ..., 161, 165, 147]],
[[157, 158, 159, ..., 117, 116, 116],
[155, 158, 159, ..., 117, 114, 117],
[155, 156, 159, ..., 120, 116, 117],
...,
[117, 117, 117, ..., 164, 159, 163],
[117, 117, 119, ..., 163, 164, 155],
[117, 117, 117, ..., 163, 164, 144]]], dtype=uint8),
array([[[148, 149, 147, ..., 149, 1, 150],
[145, 149, 145, ..., 151, 1, 152],
[147, 149, 145, ..., 152, 1, 150],
...,
[172, 173, 173, ..., 173, 1, 173],
[173, 173, 173, ..., 173, 1, 172],
[173, 172, 172, ..., 173, 1, 173]],
[[144, 145, 145, ..., 150, 150, 151],
[145, 145, 145, ..., 155, 150, 151],
[147, 145, 144, ..., 154, 154, 151],
...,
[170, 171, 173, ..., 173, 173, 174],
[170, 170, 173, ..., 176, 173, 173],
[171, 169, 169, ..., 173, 173, 173]],
[[144, 147, 145, ..., 151, 150, 150],
[145, 145, 145, ..., 150, 150, 149],
[147, 147, 147, ..., 151, 155, 149],
...,
[169, 170, 170, ..., 172, 173, 173],
[172, 170, 170, ..., 173, 174, 171],
[172, 169, 169, ..., 173, 173, 173]],
...,
[[133, 131, 133, ..., 137, 136, 135],
[134, 134, 133, ..., 141, 136, 135],
[133, 133, 133, ..., 137, 136, 135],
...,
[120, 77, 105, ..., 150, 147, 143],
[122, 98, 120, ..., 150, 149, 138],
[120, 107, 121, ..., 149, 147, 136]],
[[133, 133, 131, ..., 136, 136, 136],
[133, 131, 131, ..., 137, 136, 135],
[131, 133, 133, ..., 136, 136, 134],
...,
[117, 95, 108, ..., 148, 147, 140],
[119, 114, 119, ..., 150, 150, 135],
[120, 119, 119, ..., 145, 145, 134]],
[[131, 134, 131, ..., 136, 133, 134],
[131, 135, 130, ..., 136, 135, 133],
[133, 136, 131, ..., 136, 134, 131],
...,
[119, 114, 117, ..., 145, 145, 136],
[117, 117, 119, ..., 145, 147, 133],
[117, 120, 117, ..., 145, 145, 131]]], dtype=uint8),
array([[[172, 156, 166, ..., 174, 173, 176],
[174, 173, 173, ..., 177, 173, 176],
[173, 172, 173, ..., 176, 174, 176],
...,
[169, 171, 170, ..., 168, 169, 169],
[169, 169, 169, ..., 172, 168, 171],
[170, 169, 169, ..., 170, 170, 171]],
[[174, 155, 166, ..., 173, 174, 174],
[174, 172, 174, ..., 173, 176, 176],
[174, 174, 174, ..., 173, 176, 173],
...,
[171, 171, 172, ..., 168, 171, 169],
[172, 171, 172, ..., 169, 169, 169],
[170, 171, 169, ..., 169, 170, 169]],
[[173, 154, 166, ..., 172, 177, 173],
[174, 173, 174, ..., 174, 174, 173],
[173, 174, 173, ..., 173, 177, 173],
...,
[170, 168, 169, ..., 166, 169, 165],
[170, 169, 170, ..., 169, 169, 169],
[169, 168, 169, ..., 168, 171, 169]],
...,
[[137, 136, 136, ..., 136, 136, 135],
[136, 136, 136, ..., 136, 136, 135],
[138, 136, 136, ..., 137, 136, 135],
...,
[ 48, 50, 73, ..., 127, 131, 120],
[ 58, 66, 116, ..., 127, 127, 127],
[ 77, 85, 122, ..., 126, 127, 123]],
[[136, 136, 136, ..., 135, 133, 135],
[136, 136, 136, ..., 136, 136, 134],
[136, 135, 136, ..., 136, 134, 134],
...,
[ 62, 57, 115, ..., 127, 130, 116],
[ 98, 97, 123, ..., 125, 126, 127],
[115, 116, 126, ..., 123, 126, 122]],
[[135, 135, 135, ..., 135, 134, 135],
[136, 136, 135, ..., 135, 135, 134],
[136, 136, 135, ..., 135, 135, 134],
...,
[ 72, 85, 117, ..., 123, 122, 108],
[107, 120, 122, ..., 123, 125, 122],
[117, 123, 122, ..., 122, 126, 128]]], dtype=uint8),
array([[[172, 173, 169, ..., 169, 170, 168],
[170, 172, 168, ..., 169, 171, 171],
[169, 170, 169, ..., 173, 169, 170],
...,
[171, 170, 169, ..., 165, 163, 164],
[170, 170, 169, ..., 164, 164, 163],
[169, 169, 168, ..., 168, 164, 166]],
[[171, 170, 171, ..., 168, 172, 170],
[169, 169, 169, ..., 169, 172, 171],
[169, 170, 169, ..., 169, 171, 170],
...,
[169, 168, 171, ..., 164, 165, 166],
[169, 169, 169, ..., 163, 164, 165],
[169, 168, 169, ..., 164, 164, 169]],
[[172, 171, 171, ..., 172, 168, 168],
[172, 172, 168, ..., 172, 169, 170],
[171, 171, 168, ..., 172, 169, 169],
...,
[169, 169, 169, ..., 164, 166, 166],
[170, 169, 170, ..., 164, 166, 165],
[170, 169, 169, ..., 165, 164, 165]],
...,
[[141, 135, 131, ..., 126, 126, 127],
[159, 143, 131, ..., 126, 126, 127],
[159, 162, 136, ..., 127, 126, 127],
...,
[115, 114, 113, ..., 112, 112, 113],
[115, 115, 113, ..., 108, 111, 112],
[117, 117, 113, ..., 112, 108, 113]],
[[137, 134, 133, ..., 126, 127, 125],
[159, 138, 131, ..., 127, 127, 127],
[161, 159, 136, ..., 128, 127, 127],
...,
[115, 114, 113, ..., 113, 113, 113],
[113, 115, 113, ..., 113, 113, 108],
[114, 115, 113, ..., 112, 113, 112]],
[[136, 133, 131, ..., 126, 127, 125],
[159, 137, 131, ..., 128, 127, 127],
[163, 158, 136, ..., 127, 127, 127],
...,
[114, 113, 115, ..., 113, 112, 112],
[116, 114, 113, ..., 112, 113, 108],
[115, 113, 113, ..., 113, 113, 109]]], dtype=uint8),
array([[[164, 164, 164, ..., 39, 49, 15],
[164, 164, 164, ..., 43, 49, 28],
[165, 164, 164, ..., 43, 45, 17],
...,
[166, 164, 154, ..., 115, 113, 15],
[165, 164, 163, ..., 144, 141, 20],
[165, 164, 164, ..., 161, 159, 20]],
[[164, 163, 163, ..., 38, 46, 15],
[164, 163, 163, ..., 39, 45, 28],
[164, 164, 161, ..., 43, 45, 24],
...,
[164, 159, 151, ..., 115, 111, 20],
[165, 163, 159, ..., 144, 141, 21],
[164, 161, 162, ..., 159, 158, 18]],
[[164, 164, 164, ..., 44, 43, 11],
[164, 164, 165, ..., 42, 43, 18],
[164, 163, 164, ..., 41, 42, 29],
...,
[164, 157, 128, ..., 130, 108, 18],
[164, 164, 149, ..., 151, 137, 20],
[166, 164, 155, ..., 158, 157, 21]],
...,
[[154, 150, 150, ..., 155, 155, 93],
[152, 150, 150, ..., 155, 154, 93],
[151, 149, 149, ..., 155, 152, 93],
...,
[115, 117, 117, ..., 113, 113, 101],
[115, 116, 117, ..., 115, 113, 91],
[116, 115, 114, ..., 113, 113, 89]],
[[152, 149, 149, ..., 155, 152, 93],
[150, 149, 150, ..., 155, 152, 92],
[150, 147, 147, ..., 154, 152, 92],
...,
[112, 116, 116, ..., 113, 114, 104],
[113, 116, 115, ..., 113, 113, 85],
[115, 113, 114, ..., 113, 114, 81]],
[[150, 148, 149, ..., 155, 155, 93],
[150, 147, 149, ..., 155, 154, 92],
[150, 145, 147, ..., 155, 154, 94],
...,
[113, 114, 116, ..., 113, 115, 106],
[113, 116, 116, ..., 114, 114, 87],
[114, 115, 115, ..., 113, 115, 86]]], dtype=uint8)]],
dtype=object),
'dirnames': array([[array(['1a'], dtype='<U2'), array(['1b'], dtype='<U2'),
array(['1c'], dtype='<U2'), array(['1d'], dtype='<U2'),
array(['1e'], dtype='<U2'), array(['1f'], dtype='<U2'),
array(['1g'], dtype='<U2'), array(['1h'], dtype='<U2'),
array(['1i'], dtype='<U2'), array(['1j'], dtype='<U2'),
array(['1k'], dtype='<U2'), array(['1l'], dtype='<U2'),
array(['1m'], dtype='<U2'), array(['1n'], dtype='<U2'),
array(['1o'], dtype='<U2'), array(['1p'], dtype='<U2'),
array(['1q'], dtype='<U2'), array(['1r'], dtype='<U2'),
array(['1s'], dtype='<U2'), array(['1t'], dtype='<U2')]],
dtype=object)}
In [403]:
data, label = [], []
for idx, person in enumerate(face_data):
for i in range(person.shape[2]):
img = person[:,:,i]
img = img.reshape(10304)
data.append(img)
label.append(idx)
data, label = np.array(data), np.array(label) # LABEL: Range of label values: 0-19. When using the for-loop and print(f'Label: {label}'), each picture is tagged with the correct label per-person
print(f'data shape: {data.shape}, label shape: {label.shape}, unique people ids: {np.unique(label)}')
data shape: (575, 10304), label shape: (575,), unique people ids: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
In [404]:
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(10, 10))
for i in range(0,20):
fig.add_subplot(4,5,i+1)
plt.imshow(data[i].reshape(112,92), cmap='gray')
plt.xticks([])
plt.yticks([])
person_id_label = f'Person ID: {label[i]}'
plt.title(person_id_label).set_color('green')
In [405]:
##
In [406]:
x_train, x_test, y_train, y_test = train_test_split(data, label, test_size=0.2, random_state=random_state, stratify=label)
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.25, random_state=random_state, stratify=y_train)
print(f'x_train: {x_train.shape}\ny_train: {y_train.shape}\nx_test: {x_test.shape}\ny_test: {y_test.shape}\nx_val: {x_val.shape}\ny_val: {y_val.shape}')
x_train: (345, 10304) y_train: (345,) x_test: (115, 10304) y_test: (115,) x_val: (115, 10304) y_val: (115,)
In [407]:
# Standard scaling
scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test)
x_val = scaler.transform(x_val)
y_train = torch.eye(20)[torch.LongTensor(y_train)]
y_test = torch.eye(20)[torch.LongTensor(y_test)]
y_val = torch.eye(20)[torch.LongTensor(y_val)]
print(f'x_train: {x_train.shape}\ny_train: {y_train.shape}\nx_test: {x_test.shape}\ny_test: {y_test.shape}\nx_val: {x_val.shape}\ny_val: {y_val.shape}')
x_train: (345, 10304) y_train: torch.Size([345, 20]) x_test: (115, 10304) y_test: torch.Size([115, 20]) x_val: (115, 10304) y_val: torch.Size([115, 20])
In [408]:
pca = PCA(n_components=10, random_state=random_state)
x_train_pca = pca.fit_transform(x_train)
print(f'Explained variance: {sum(pca.explained_variance_ratio_)}')
x_test_pca = pca.transform(x_test)
x_val_pca = pca.transform(x_val)
Explained variance: 0.6492373520222806
In [409]:
print(f'x_train_pca: {x_train_pca.shape}\nx_test_pca: {x_test_pca.shape}\nx_val_pca: {x_val_pca.shape}')
x_train_pca: (345, 10) x_test_pca: (115, 10) x_val_pca: (115, 10)
In [410]:
face_recovered = pca.inverse_transform(x_train_pca) # Face images in the stratified training set, 0 to 344
plt.gray()
plt.xticks([])
plt.yticks([])
plt.title('Sample Image after PCA').set_color('green')
plt.imshow(np.array(face_recovered[111]).reshape(112,92))
Out[410]:
<matplotlib.image.AxesImage at 0x277c53853a0>
In [411]:
kmeans = KMeans(n_clusters=20, random_state=random_state)
x_train_kmeans = kmeans.fit_transform(x_train_pca)
x_test_kmeans = kmeans.transform(x_test_pca)
x_val_kmeans = kmeans.transform(x_val_pca)
x_train_tensor = torch.FloatTensor(x_train_kmeans)
x_test_tensor = torch.FloatTensor(x_test_kmeans)
x_val_tensor = torch.FloatTensor(x_val_kmeans)
D:\anaconda3\Lib\site-packages\sklearn\cluster\_kmeans.py:1429: UserWarning: KMeans is known to have a memory leak on Windows with MKL, when there are less chunks than available threads. You can avoid it by setting the environment variable OMP_NUM_THREADS=2. warnings.warn(
In [426]:
x_train_kmeans
Out[426]:
array([[ 85.34897988, 75.84730645, 76.09119472, ..., 97.77866986,
83.10348576, 51.85576079],
[ 77.5241804 , 131.31719569, 105.1927116 , ..., 108.86916624,
93.96202626, 121.07083419],
[ 98.44347613, 104.5782617 , 84.55857067, ..., 98.99388958,
75.12079416, 88.7045111 ],
...,
[118.08961643, 35.20255626, 94.85801187, ..., 77.82242491,
120.16538824, 94.55763113],
[105.803014 , 92.04468724, 91.22768778, ..., 105.86139586,
83.12872025, 41.19215493],
[ 91.05638838, 99.60399459, 76.91126016, ..., 99.80034206,
87.99432228, 91.87189211]])
In [412]:
print(f'x_train_kmeans: {x_train_kmeans.shape}\nx_test_kmeans: {x_test_kmeans.shape}\nx_val_kmeans: {x_val_kmeans.shape}')
x_train_kmeans: (345, 20) x_test_kmeans: (115, 20) x_val_kmeans: (115, 20)
In [413]:
print(f'x_train_kmeans: {x_train_tensor.shape}\nx_test_kmeans: {x_train_tensor.shape}\nx_val_kmeans: {x_train_tensor.shape}')
x_train_kmeans: torch.Size([345, 20]) x_test_kmeans: torch.Size([345, 20]) x_val_kmeans: torch.Size([345, 20])
In [414]:
print(f'x_train_reshaped: {x_train_reshaped.shape}\ny_train_reshaped: {y_train_reshaped.shape}\nx_test_reshaped: {x_test_reshaped.shape}\ny_test_reshaped: {y_test_reshaped.shape}\nx_val_reshaped: {x_val_reshaped.shape}\ny_val_reshaped: {y_val_reshaped.shape}')
x_train_reshaped: (345, 1, 20) y_train_reshaped: torch.Size([345, 1, 20]) x_test_reshaped: (115, 1, 20) y_test_reshaped: torch.Size([115, 1, 20]) x_val_reshaped: (115, 1, 20) y_val_reshaped: torch.Size([115, 1, 20])
In [415]:
class FaceRecognitionModel(nn.Module):
def __init__(self, input_size, num_classes, units, dropout=False):
super(FaceRecognitionModel, self).__init__()
self.layers = nn.Sequential(
nn.Linear(input_size, 34),
nn.Linear(34, units),
nn.ReLU(),
nn.Dropout(0.2) if dropout else nn.Identity(),
nn.Linear(units, units),
nn.ReLU(),
nn.Dropout(0.2) if dropout else nn.Identity(),
nn.Linear(units, num_classes),
nn.Softmax(dim=1)
)
def forward(self, x):
x = x.view(x.size(0), -1) # Flatten the input
return self.layers(x)
In [416]:
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=100):
best_val_accuracy = 0
best_model = None
for epoch in range(epochs):
model.train()
train_loss = 0
train_correct = 0
total_train = 0
for inputs, labels in train_loader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = torch.max(outputs, 1)
_, true_labels = torch.max(labels, 1)
total_train += labels.size(0)
train_correct += (predicted == true_labels).sum().item()
# Validation
model.eval()
val_loss = 0
val_correct = 0
total_val = 0
with torch.no_grad():
for inputs, labels in val_loader:
outputs = model(inputs)
loss = criterion(outputs, labels)
val_loss += loss.item()
_, predicted = torch.max(outputs, 1)
_, true_labels = torch.max(labels, 1)
total_val += labels.size(0)
val_correct += (predicted == true_labels).sum().item()
train_accuracy = 100 * train_correct / total_train
val_accuracy = 100 * val_correct / total_val
if val_accuracy > best_val_accuracy:
best_val_accuracy = val_accuracy
best_model = model.state_dict()
if epoch % 10 == 0:
print(f'Epoch {epoch}: Train Loss {train_loss/len(train_loader):.4f}, '
f'Train Acc {train_accuracy:.2f}%, '
f'Val Loss {val_loss/len(val_loader):.4f}, '
f'Val Acc {val_accuracy:.2f}%')
return best_model
In [417]:
def visualize_predictions(model, x_data, y_data, x_original, title):
model.eval()
plt.figure(figsize=(15, 10))
# Predict for first 20 samples
predictions = []
with torch.no_grad():
for i in range(20):
inputs = x_data[i].unsqueeze(0)
output = model(inputs)
pred = torch.argmax(output).item()
true_label = torch.argmax(y_data[i]).item()
predictions.append(pred)
plt.subplot(4, 5, i+1)
# Reconstruct and show original image
img = x_original[i].reshape(112, 92)
plt.imshow(img, cmap='gray')
plt.title(f'Pred: {pred}, Actual: {true_lab el}')
plt.axis('off')
plt.suptitle(title)
plt.tight_layout()
plt.show()
return predictions
In [418]:
print(f'x_train_kmeans: {x_train_kmeans.shape}\nx_test_kmeans: {x_test_kmeans.shape}\nx_val_kmeans: {x_val_kmeans.shape}')
x_train_kmeans: (345, 20) x_test_kmeans: (115, 20) x_val_kmeans: (115, 20)
In [419]:
print(x_train.shape[1])
10304
In [420]:
print(x_train_tensor.shape)
torch.Size([345, 20])
In [421]:
# Create datasets and loaders
train_dataset = TensorDataset(x_train_tensor, y_train)
val_dataset = TensorDataset(x_val_tensor, y_val)
test_dataset = TensorDataset(x_test_tensor, y_test)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)
# Hyperparameters
model_params = {
'input_size': x_train_tensor.shape[1],
'num_classes': 20,
'units': 128,
'dropout': True
}
# Initialize model
model = FaceRecognitionModel(**model_params)
# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Train the model
best_model_state = train_model(model, train_loader, val_loader, criterion, optimizer)
# Load best model
model.load_state_dict(best_model_state)
# Evaluate on test set
model.eval()
test_correct = 0
total_test = 0
with torch.no_grad():
for inputs, labels in test_loader:
outputs = model(inputs)
_, predicted = torch.max(outputs, 1)
_, true_labels = torch.max(labels, 1)
total_test += labels.size(0)
test_correct += (predicted == true_labels).sum().item()
test_accuracy = 100 * test_correct / total_test
print(f'Test Accuracy: {test_accuracy:.2f}%')
Epoch 0: Train Loss 0.0802, Train Acc 6.96%, Val Loss 0.0650, Val Acc 11.30% Epoch 10: Train Loss 0.0399, Train Acc 32.75%, Val Loss 0.0371, Val Acc 37.39% Epoch 20: Train Loss 0.0255, Train Acc 62.61%, Val Loss 0.0229, Val Acc 66.96% Epoch 30: Train Loss 0.0168, Train Acc 76.52%, Val Loss 0.0128, Val Acc 83.48% Epoch 40: Train Loss 0.0115, Train Acc 84.06%, Val Loss 0.0080, Val Acc 90.43% Epoch 50: Train Loss 0.0102, Train Acc 84.93%, Val Loss 0.0075, Val Acc 90.43% Epoch 60: Train Loss 0.0065, Train Acc 91.30%, Val Loss 0.0041, Val Acc 94.78% Epoch 70: Train Loss 0.0060, Train Acc 92.75%, Val Loss 0.0037, Val Acc 94.78% Epoch 80: Train Loss 0.0055, Train Acc 91.59%, Val Loss 0.0043, Val Acc 93.91% Epoch 90: Train Loss 0.0040, Train Acc 95.07%, Val Loss 0.0022, Val Acc 97.39% Test Accuracy: 94.78%
In [422]:
# Visualize predictions
print("Visualizing Training Set Predictions")
train_predictions = visualize_predictions(model, x_train_tensor, y_train, x_train_orig, 'Training Set Predictions')
Visualizing Training Set Predictions
In [423]:
print("Visualizing Test Set Predictions")
test_predictions = visualize_predictions(model, x_test_tensor, y_test, x_test_orig, 'Test Set Predictions')
Visualizing Test Set Predictions
In [424]:
print("Visualizing Validation Set Predictions")
val_predictions = visualize_predictions(model, x_val_tensor, y_val, x_val_orig, 'Validation Set Predictions')
Visualizing Validation Set Predictions
In [425]:
torch.save(model.state_dict(), 'face_recognition_model.pth')
In [ ]: