In [399]:
##Loading
In [400]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import scipy.io as sio
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
random_state=42
np.random.seed(random_state)
torch.manual_seed(random_state)
Out[400]:
<torch._C.Generator at 0x277da6e4e30>
In [401]:
mat_dict = sio.loadmat(file_name='umist_cropped.mat', appendmat=False) # Load MATLAB file into a Python dictionary REF (scipy.io.loadmat): https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.loadmat.html
print(f'mat_dict type: {type(mat_dict)}, mat_dict.keys(): {mat_dict.keys()}')
print(f'\nfacedat type: {type(mat_dict["facedat"])}, facedat shape: {mat_dict["facedat"].shape}')
print(f'dirnames type: {type(mat_dict["dirnames"])}, dirnames shape: {mat_dict["dirnames"].shape}')
face_data = mat_dict['facedat'][0] # Load the face data from the dictionary. Produces numpy.ndarray
print(f'\nNumber of people in images [len(face_data)]: {len(face_data)}\nImage height [len(face_data[0]]: {len(face_data[0])}\nImage width [len(face_data[0][0]]: {len(face_data[0][0])}\nImage total pixels [height X width]: {len(face_data[0]) * len(face_data[0][0])}')
mat_dict type: <class 'dict'>, mat_dict.keys(): dict_keys(['__header__', '__version__', '__globals__', 'facedat', 'dirnames']) facedat type: <class 'numpy.ndarray'>, facedat shape: (1, 20) dirnames type: <class 'numpy.ndarray'>, dirnames shape: (1, 20) Number of people in images [len(face_data)]: 20 Image height [len(face_data[0]]: 112 Image width [len(face_data[0][0]]: 92 Image total pixels [height X width]: 10304
In [402]:
mat_dict
Out[402]:
{'__header__': b'MATLAB 5.0 MAT-file, Platform: GLNX86, Created on: Wed Aug 28 11:38:19 2002', '__version__': '1.0', '__globals__': [], 'facedat': array([[array([[[233, 234, 234, ..., 236, 230, 234], [234, 234, 234, ..., 235, 234, 232], [234, 234, 234, ..., 236, 233, 234], ..., [234, 234, 234, ..., 237, 232, 233], [234, 234, 234, ..., 236, 233, 234], [234, 234, 233, ..., 236, 234, 233]], [[234, 234, 234, ..., 236, 232, 233], [234, 234, 234, ..., 234, 233, 234], [234, 234, 234, ..., 235, 233, 233], ..., [234, 234, 234, ..., 236, 230, 234], [234, 234, 234, ..., 236, 234, 233], [234, 234, 234, ..., 237, 233, 230]], [[234, 233, 234, ..., 236, 232, 233], [234, 234, 233, ..., 235, 230, 234], [234, 234, 234, ..., 237, 233, 233], ..., [234, 234, 234, ..., 237, 232, 233], [234, 233, 234, ..., 239, 232, 233], [234, 232, 234, ..., 236, 230, 232]], ..., [[166, 190, 206, ..., 219, 215, 187], [177, 202, 208, ..., 216, 215, 186], [194, 204, 205, ..., 213, 218, 187], ..., [147, 145, 161, ..., 142, 187, 211], [144, 142, 157, ..., 176, 187, 214], [141, 142, 155, ..., 187, 187, 215]], [[ 92, 179, 206, ..., 216, 219, 187], [113, 197, 206, ..., 215, 220, 187], [143, 199, 201, ..., 215, 220, 190], ..., [141, 142, 157, ..., 186, 188, 215], [142, 138, 152, ..., 190, 187, 216], [140, 140, 154, ..., 192, 187, 216]], [[ 65, 104, 204, ..., 215, 218, 184], [100, 135, 206, ..., 214, 220, 184], [123, 159, 194, ..., 215, 220, 188], ..., [141, 138, 148, ..., 192, 187, 215], [140, 137, 145, ..., 192, 187, 215], [136, 137, 145, ..., 193, 185, 215]]], dtype=uint8), array([[[232, 233, 232, ..., 234, 234, 221], [229, 232, 232, ..., 234, 234, 209], [230, 230, 232, ..., 234, 234, 140], ..., [230, 229, 230, ..., 42, 43, 66], [229, 229, 230, ..., 64, 64, 128], [229, 230, 230, ..., 91, 105, 201]], [[229, 232, 232, ..., 233, 234, 207], [234, 230, 230, ..., 234, 234, 174], [229, 230, 229, ..., 233, 234, 102], ..., [229, 232, 232, ..., 41, 39, 57], [229, 230, 229, ..., 59, 51, 86], [229, 230, 229, ..., 87, 84, 165]], [[232, 230, 230, ..., 234, 234, 165], [230, 229, 230, ..., 234, 233, 108], [229, 230, 229, ..., 234, 234, 93], ..., [229, 229, 230, ..., 38, 38, 56], [229, 230, 229, ..., 45, 44, 63], [229, 230, 229, ..., 70, 57, 104]], ..., [[ 83, 85, 97, ..., 200, 200, 200], [ 87, 92, 93, ..., 200, 200, 201], [ 92, 97, 92, ..., 201, 197, 197], ..., [122, 122, 122, ..., 127, 138, 162], [121, 123, 120, ..., 126, 130, 152], [123, 123, 122, ..., 125, 126, 143]], [[ 80, 84, 100, ..., 198, 199, 197], [ 85, 92, 99, ..., 200, 198, 197], [ 85, 92, 95, ..., 200, 197, 195], ..., [122, 122, 125, ..., 125, 130, 158], [122, 122, 122, ..., 126, 128, 149], [122, 121, 123, ..., 125, 125, 138]], [[ 83, 78, 99, ..., 161, 193, 197], [ 84, 84, 101, ..., 182, 197, 197], [ 77, 86, 105, ..., 198, 197, 195], ..., [121, 122, 123, ..., 123, 127, 148], [122, 122, 123, ..., 126, 125, 140], [122, 122, 123, ..., 125, 123, 131]]], dtype=uint8), array([[[233, 230, 234, ..., 234, 234, 234], [232, 209, 232, ..., 234, 234, 234], [230, 113, 195, ..., 234, 234, 234], ..., [230, 232, 232, ..., 234, 234, 234], [232, 232, 233, ..., 234, 234, 234], [233, 232, 230, ..., 234, 234, 234]], [[233, 229, 230, ..., 234, 235, 234], [230, 197, 229, ..., 234, 234, 234], [232, 94, 194, ..., 234, 234, 234], ..., [230, 232, 232, ..., 234, 234, 234], [232, 229, 230, ..., 234, 234, 234], [230, 229, 230, ..., 234, 234, 234]], [[233, 178, 229, ..., 234, 234, 234], [232, 143, 208, ..., 234, 234, 234], [233, 57, 135, ..., 234, 234, 234], ..., [230, 232, 233, ..., 234, 233, 234], [232, 230, 229, ..., 234, 234, 234], [233, 229, 230, ..., 233, 234, 234]], ..., [[170, 194, 194, ..., 205, 201, 201], [173, 198, 199, ..., 205, 202, 201], [192, 193, 197, ..., 201, 201, 200], ..., [123, 114, 71, ..., 135, 127, 147], [123, 123, 109, ..., 133, 127, 143], [123, 126, 123, ..., 128, 126, 135]], [[168, 190, 190, ..., 201, 205, 201], [170, 193, 197, ..., 204, 204, 201], [190, 193, 193, ..., 200, 201, 201], ..., [122, 119, 78, ..., 134, 127, 141], [122, 122, 112, ..., 129, 127, 138], [122, 123, 123, ..., 126, 127, 131]], [[166, 188, 190, ..., 202, 201, 202], [170, 197, 195, ..., 201, 201, 201], [186, 192, 194, ..., 200, 201, 201], ..., [122, 121, 115, ..., 128, 126, 133], [120, 121, 122, ..., 126, 126, 130], [121, 122, 123, ..., 122, 126, 127]]], dtype=uint8), array([[[141, 234, 230, ..., 225, 226, 225], [116, 232, 233, ..., 225, 228, 226], [104, 234, 232, ..., 226, 230, 228], ..., [230, 233, 232, ..., 234, 234, 234], [232, 232, 232, ..., 234, 234, 234], [232, 233, 232, ..., 234, 234, 234]], [[129, 233, 233, ..., 225, 225, 222], [111, 233, 233, ..., 225, 227, 223], [100, 234, 233, ..., 229, 229, 225], ..., [229, 233, 232, ..., 234, 234, 234], [233, 232, 230, ..., 234, 234, 234], [229, 229, 232, ..., 234, 234, 234]], [[116, 232, 233, ..., 226, 225, 222], [111, 233, 234, ..., 225, 225, 225], [ 70, 232, 233, ..., 229, 229, 225], ..., [232, 233, 232, ..., 234, 234, 233], [230, 230, 232, ..., 234, 234, 234], [232, 230, 232, ..., 234, 233, 232]], ..., [[194, 197, 201, ..., 173, 171, 171], [190, 197, 201, ..., 171, 171, 172], [ 90, 197, 199, ..., 173, 173, 172], ..., [129, 126, 128, ..., 183, 180, 180], [129, 127, 127, ..., 180, 179, 179], [128, 126, 127, ..., 183, 179, 180]], [[195, 199, 200, ..., 174, 173, 171], [192, 199, 200, ..., 173, 173, 172], [138, 199, 197, ..., 174, 174, 171], ..., [128, 127, 126, ..., 184, 182, 178], [129, 127, 127, ..., 183, 180, 178], [130, 127, 127, ..., 183, 183, 178]], [[194, 200, 200, ..., 173, 173, 170], [192, 200, 198, ..., 173, 172, 172], [147, 197, 197, ..., 173, 174, 170], ..., [127, 127, 126, ..., 182, 182, 178], [127, 127, 127, ..., 182, 182, 177], [127, 128, 125, ..., 183, 183, 171]]], dtype=uint8), array([[[193, 193, 192, ..., 197, 193, 194], [192, 193, 192, ..., 197, 192, 194], [192, 193, 190, ..., 197, 193, 192], ..., [193, 192, 192, ..., 195, 192, 191], [193, 192, 192, ..., 194, 193, 192], [192, 193, 192, ..., 194, 193, 192]], [[192, 192, 193, ..., 197, 197, 195], [191, 192, 191, ..., 197, 194, 195], [191, 192, 192, ..., 197, 194, 193], ..., [191, 191, 192, ..., 192, 194, 192], [192, 192, 192, ..., 192, 192, 193], [192, 192, 194, ..., 192, 194, 193]], [[192, 192, 192, ..., 195, 197, 195], [191, 192, 192, ..., 194, 193, 195], [191, 191, 191, ..., 192, 193, 193], ..., [192, 192, 192, ..., 192, 194, 193], [192, 192, 192, ..., 192, 193, 192], [191, 192, 192, ..., 192, 195, 192]], ..., [[145, 156, 148, ..., 147, 169, 169], [166, 172, 169, ..., 148, 174, 173], [168, 165, 168, ..., 165, 173, 171], ..., [123, 122, 125, ..., 155, 155, 154], [125, 123, 127, ..., 155, 155, 155], [122, 122, 127, ..., 155, 155, 155]], [[143, 155, 145, ..., 144, 164, 165], [164, 169, 171, ..., 148, 173, 173], [166, 168, 170, ..., 164, 172, 171], ..., [122, 121, 125, ..., 155, 155, 155], [123, 122, 127, ..., 155, 155, 152], [122, 122, 126, ..., 155, 154, 154]], [[144, 154, 145, ..., 144, 164, 164], [163, 169, 168, ..., 145, 173, 171], [168, 166, 170, ..., 161, 171, 170], ..., [121, 120, 123, ..., 155, 155, 154], [122, 122, 123, ..., 154, 155, 152], [122, 123, 125, ..., 150, 154, 152]]], dtype=uint8), array([[[173, 173, 173, ..., 174, 174, 173], [173, 176, 171, ..., 176, 174, 172], [173, 173, 170, ..., 176, 173, 173], ..., [177, 179, 178, ..., 178, 179, 179], [178, 179, 177, ..., 179, 178, 178], [178, 178, 176, ..., 178, 179, 178]], [[173, 173, 174, ..., 173, 173, 173], [172, 173, 174, ..., 173, 173, 171], [172, 173, 173, ..., 173, 174, 173], ..., [176, 178, 179, ..., 178, 178, 177], [177, 178, 178, ..., 178, 178, 177], [174, 177, 178, ..., 178, 177, 177]], [[170, 173, 173, ..., 174, 173, 173], [171, 173, 173, ..., 173, 173, 174], [171, 171, 173, ..., 173, 173, 170], ..., [176, 176, 178, ..., 177, 178, 176], [176, 178, 178, ..., 178, 178, 178], [176, 177, 178, ..., 177, 178, 178]], ..., [[ 29, 34, 35, ..., 87, 86, 147], [ 29, 34, 34, ..., 67, 66, 115], [ 29, 30, 30, ..., 43, 43, 63], ..., [119, 119, 117, ..., 119, 119, 123], [119, 121, 117, ..., 117, 116, 120], [117, 119, 119, ..., 117, 117, 119]], [[ 30, 34, 35, ..., 70, 66, 120], [ 29, 31, 33, ..., 55, 53, 79], [ 29, 29, 29, ..., 37, 36, 52], ..., [117, 120, 117, ..., 117, 117, 120], [117, 120, 117, ..., 117, 116, 119], [119, 121, 119, ..., 116, 115, 117]], [[ 36, 30, 30, ..., 55, 52, 81], [ 36, 29, 29, ..., 46, 45, 63], [ 38, 30, 31, ..., 36, 36, 44], ..., [117, 120, 116, ..., 116, 117, 116], [117, 120, 117, ..., 117, 114, 117], [115, 119, 117, ..., 116, 113, 117]]], dtype=uint8), array([[[179, 177, 178, ..., 168, 145, 148], [179, 178, 182, ..., 173, 147, 150], [177, 178, 180, ..., 178, 147, 149], ..., [163, 162, 164, ..., 164, 164, 164], [163, 161, 164, ..., 164, 164, 164], [163, 161, 164, ..., 164, 164, 165]], [[178, 174, 177, ..., 166, 148, 147], [178, 178, 178, ..., 176, 150, 147], [173, 180, 178, ..., 178, 150, 147], ..., [162, 163, 162, ..., 164, 164, 163], [161, 164, 161, ..., 164, 164, 163], [162, 163, 162, ..., 164, 164, 164]], [[178, 174, 177, ..., 163, 145, 148], [178, 178, 178, ..., 172, 148, 148], [177, 180, 179, ..., 177, 147, 148], ..., [159, 163, 162, ..., 164, 164, 163], [161, 163, 163, ..., 164, 164, 163], [159, 163, 161, ..., 164, 164, 164]], ..., [[ 43, 43, 31, ..., 69, 36, 42], [ 43, 43, 28, ..., 55, 31, 46], [ 37, 43, 35, ..., 23, 24, 45], ..., [113, 113, 113, ..., 115, 114, 113], [114, 112, 114, ..., 114, 115, 114], [114, 113, 113, ..., 114, 113, 113]], [[ 43, 43, 30, ..., 72, 34, 37], [ 43, 43, 34, ..., 60, 31, 44], [ 37, 43, 42, ..., 24, 24, 44], ..., [113, 113, 113, ..., 113, 116, 113], [113, 113, 113, ..., 113, 114, 113], [113, 114, 113, ..., 113, 114, 113]], [[ 43, 49, 31, ..., 92, 35, 29], [ 43, 44, 31, ..., 77, 30, 37], [ 41, 43, 38, ..., 38, 23, 43], ..., [112, 113, 113, ..., 113, 116, 114], [113, 113, 113, ..., 113, 115, 115], [113, 114, 113, ..., 113, 113, 114]]], dtype=uint8), array([[[169, 169, 170, ..., 169, 170, 169], [168, 169, 170, ..., 170, 169, 169], [168, 169, 169, ..., 170, 169, 165], ..., [158, 157, 161, ..., 157, 158, 159], [158, 158, 158, ..., 157, 159, 159], [157, 159, 159, ..., 159, 157, 159]], [[170, 172, 168, ..., 170, 171, 164], [172, 172, 169, ..., 169, 170, 165], [168, 170, 168, ..., 170, 172, 165], ..., [159, 159, 158, ..., 159, 159, 156], [159, 161, 157, ..., 159, 159, 158], [159, 162, 158, ..., 159, 161, 158]], [[169, 169, 169, ..., 169, 170, 166], [169, 169, 169, ..., 169, 169, 165], [166, 168, 169, ..., 169, 168, 165], ..., [157, 159, 159, ..., 157, 158, 159], [157, 161, 159, ..., 159, 159, 158], [158, 159, 161, ..., 159, 159, 159]], ..., [[ 41, 37, 53, ..., 163, 157, 159], [ 39, 38, 62, ..., 157, 151, 164], [ 34, 36, 57, ..., 151, 150, 164], ..., [ 63, 72, 76, ..., 127, 97, 136], [ 91, 102, 105, ..., 122, 97, 136], [113, 113, 113, ..., 121, 101, 136]], [[ 39, 44, 37, ..., 161, 163, 157], [ 37, 43, 41, ..., 159, 161, 163], [ 34, 37, 53, ..., 159, 158, 163], ..., [ 66, 90, 98, ..., 127, 129, 136], [ 94, 113, 113, ..., 119, 128, 135], [112, 113, 113, ..., 117, 127, 135]], [[ 44, 44, 34, ..., 161, 161, 156], [ 43, 43, 34, ..., 161, 159, 159], [ 38, 38, 41, ..., 159, 158, 159], ..., [ 80, 94, 112, ..., 123, 129, 134], [109, 111, 113, ..., 117, 127, 131], [113, 113, 112, ..., 116, 125, 134]]], dtype=uint8), array([[[173, 187, 183, ..., 186, 187, 190], [161, 186, 183, ..., 184, 184, 187], [141, 183, 183, ..., 179, 180, 186], ..., [174, 173, 173, ..., 147, 178, 177], [176, 173, 173, ..., 156, 178, 174], [176, 173, 173, ..., 162, 178, 177]], [[166, 186, 185, ..., 184, 182, 190], [156, 186, 184, ..., 180, 182, 191], [134, 184, 184, ..., 177, 178, 187], ..., [173, 173, 174, ..., 142, 178, 178], [172, 173, 177, ..., 150, 176, 178], [173, 174, 176, ..., 156, 177, 178]], [[152, 185, 183, ..., 179, 182, 188], [140, 183, 183, ..., 174, 177, 187], [122, 184, 182, ..., 166, 172, 188], ..., [174, 173, 173, ..., 142, 178, 177], [173, 173, 176, ..., 149, 177, 178], [174, 173, 173, ..., 150, 178, 176]], ..., [[ 39, 45, 38, ..., 39, 45, 38], [ 42, 45, 38, ..., 37, 46, 38], [ 37, 46, 37, ..., 36, 48, 35], ..., [114, 114, 113, ..., 33, 76, 77], [113, 114, 113, ..., 43, 97, 84], [114, 116, 113, ..., 56, 104, 86]], [[ 38, 44, 38, ..., 37, 44, 38], [ 42, 45, 35, ..., 34, 46, 38], [ 35, 45, 38, ..., 33, 48, 34], ..., [115, 113, 113, ..., 35, 99, 76], [114, 113, 114, ..., 57, 107, 84], [116, 113, 114, ..., 76, 113, 86]], [[ 39, 42, 34, ..., 36, 37, 37], [ 42, 44, 34, ..., 34, 41, 37], [ 34, 45, 37, ..., 33, 43, 34], ..., [113, 113, 113, ..., 34, 101, 85], [113, 113, 113, ..., 58, 107, 88], [114, 114, 113, ..., 78, 112, 92]]], dtype=uint8), array([[[214, 214, 212, ..., 215, 206, 218], [213, 213, 211, ..., 218, 207, 215], [212, 213, 214, ..., 215, 204, 215], ..., [205, 204, 201, ..., 202, 215, 205], [204, 205, 201, ..., 204, 219, 205], [202, 204, 204, ..., 204, 218, 206]], [[215, 215, 214, ..., 216, 204, 215], [215, 214, 214, ..., 216, 205, 214], [214, 214, 214, ..., 215, 201, 215], ..., [206, 206, 205, ..., 206, 214, 204], [206, 206, 206, ..., 206, 215, 206], [206, 206, 206, ..., 206, 215, 201]], [[215, 215, 214, ..., 215, 205, 215], [215, 215, 215, ..., 215, 204, 214], [214, 214, 213, ..., 215, 201, 214], ..., [206, 206, 205, ..., 206, 215, 202], [206, 206, 205, ..., 206, 216, 201], [206, 206, 205, ..., 206, 215, 206]], ..., [[200, 201, 199, ..., 173, 131, 173], [200, 201, 198, ..., 173, 130, 173], [199, 201, 198, ..., 174, 131, 174], ..., [130, 131, 131, ..., 128, 174, 129], [131, 131, 131, ..., 128, 176, 128], [130, 131, 131, ..., 128, 172, 128]], [[201, 201, 199, ..., 172, 129, 172], [201, 200, 198, ..., 171, 130, 173], [201, 198, 194, ..., 172, 130, 173], ..., [131, 131, 130, ..., 128, 173, 129], [131, 131, 129, ..., 128, 173, 129], [131, 131, 129, ..., 128, 173, 128]], [[201, 201, 197, ..., 170, 129, 172], [201, 201, 197, ..., 173, 128, 172], [199, 198, 193, ..., 171, 130, 173], ..., [131, 131, 131, ..., 127, 172, 129], [133, 130, 127, ..., 129, 174, 129], [131, 131, 129, ..., 128, 171, 128]]], dtype=uint8), array([[[207, 206, 211, ..., 218, 220, 219], [207, 207, 211, ..., 219, 220, 218], [207, 206, 208, ..., 216, 220, 219], ..., [206, 202, 206, ..., 208, 208, 209], [206, 204, 206, ..., 208, 211, 207], [205, 205, 204, ..., 208, 209, 208]], [[207, 208, 208, ..., 220, 219, 219], [206, 209, 208, ..., 219, 219, 220], [206, 207, 208, ..., 220, 218, 220], ..., [204, 205, 202, ..., 208, 207, 209], [204, 206, 204, ..., 211, 208, 211], [201, 205, 202, ..., 211, 208, 211]], [[209, 209, 211, ..., 218, 220, 220], [208, 208, 209, ..., 215, 220, 220], [209, 206, 208, ..., 218, 218, 219], ..., [206, 202, 206, ..., 206, 209, 211], [206, 204, 204, ..., 207, 209, 211], [204, 204, 206, ..., 208, 211, 211]], ..., [[192, 195, 192, ..., 206, 208, 206], [192, 192, 192, ..., 206, 206, 206], [191, 192, 191, ..., 204, 206, 206], ..., [133, 134, 131, ..., 137, 137, 136], [134, 134, 134, ..., 136, 136, 136], [133, 134, 133, ..., 135, 136, 136]], [[192, 192, 193, ..., 206, 206, 206], [193, 192, 192, ..., 205, 206, 205], [191, 191, 192, ..., 202, 204, 206], ..., [133, 131, 133, ..., 136, 137, 135], [131, 133, 131, ..., 136, 136, 138], [131, 134, 131, ..., 136, 136, 137]], [[192, 192, 192, ..., 206, 206, 205], [190, 191, 192, ..., 206, 204, 204], [188, 190, 188, ..., 206, 204, 205], ..., [131, 131, 131, ..., 136, 136, 135], [131, 131, 131, ..., 136, 136, 136], [131, 131, 131, ..., 136, 136, 136]]], dtype=uint8), array([[[225, 225, 226, ..., 229, 229, 229], [223, 225, 225, ..., 229, 229, 229], [225, 225, 225, ..., 229, 229, 229], ..., [214, 214, 206, ..., 215, 215, 214], [216, 215, 215, ..., 216, 215, 214], [216, 218, 215, ..., 215, 213, 214]], [[225, 225, 225, ..., 229, 229, 229], [223, 226, 225, ..., 229, 229, 229], [221, 223, 225, ..., 229, 229, 229], ..., [215, 207, 201, ..., 215, 216, 215], [219, 215, 212, ..., 218, 216, 215], [220, 219, 213, ..., 215, 216, 215]], [[225, 225, 225, ..., 229, 229, 228], [221, 225, 226, ..., 229, 229, 229], [218, 222, 225, ..., 229, 229, 229], ..., [215, 201, 176, ..., 216, 215, 216], [218, 213, 202, ..., 219, 216, 216], [218, 215, 206, ..., 215, 215, 219]], ..., [[ 16, 16, 208, ..., 192, 204, 190], [ 16, 15, 197, ..., 201, 214, 195], [ 15, 15, 142, ..., 220, 218, 208], ..., [141, 141, 143, ..., 141, 136, 136], [141, 141, 143, ..., 141, 140, 136], [141, 141, 144, ..., 140, 137, 140]], [[ 15, 20, 215, ..., 192, 202, 187], [ 15, 15, 212, ..., 198, 213, 193], [ 15, 16, 173, ..., 215, 218, 204], ..., [141, 142, 142, ..., 136, 137, 136], [141, 142, 142, ..., 141, 141, 136], [144, 141, 143, ..., 137, 141, 136]], [[ 15, 19, 214, ..., 190, 201, 186], [ 16, 15, 211, ..., 194, 213, 191], [ 15, 15, 174, ..., 216, 220, 201], ..., [142, 142, 140, ..., 137, 140, 136], [141, 142, 141, ..., 140, 141, 135], [142, 142, 141, ..., 137, 140, 136]]], dtype=uint8), array([[[170, 168, 168, ..., 171, 170, 170], [170, 168, 166, ..., 171, 168, 170], [172, 170, 169, ..., 169, 170, 169], ..., [173, 169, 168, ..., 173, 169, 173], [173, 170, 169, ..., 172, 170, 173], [173, 171, 169, ..., 171, 169, 171]], [[170, 165, 169, ..., 170, 171, 169], [169, 168, 168, ..., 171, 170, 168], [170, 169, 169, ..., 169, 171, 169], ..., [171, 168, 170, ..., 173, 172, 171], [172, 166, 171, ..., 170, 171, 170], [172, 166, 172, ..., 170, 173, 169]], [[169, 166, 168, ..., 170, 171, 166], [168, 166, 169, ..., 169, 169, 168], [169, 166, 168, ..., 170, 170, 169], ..., [171, 165, 169, ..., 169, 173, 170], [171, 166, 168, ..., 172, 171, 169], [170, 169, 169, ..., 170, 172, 169]], ..., [[161, 162, 161, ..., 137, 140, 137], [159, 161, 159, ..., 138, 138, 138], [159, 161, 161, ..., 140, 142, 141], ..., [117, 117, 117, ..., 117, 117, 117], [116, 116, 114, ..., 117, 117, 117], [117, 116, 115, ..., 116, 117, 117]], [[159, 161, 163, ..., 137, 141, 136], [158, 159, 161, ..., 138, 138, 137], [158, 159, 159, ..., 140, 140, 138], ..., [116, 117, 116, ..., 117, 117, 117], [115, 113, 117, ..., 117, 117, 117], [115, 115, 117, ..., 116, 119, 117]], [[159, 159, 164, ..., 136, 140, 136], [159, 159, 159, ..., 137, 138, 137], [158, 159, 159, ..., 140, 141, 138], ..., [115, 114, 116, ..., 119, 116, 117], [114, 115, 117, ..., 117, 117, 119], [114, 115, 116, ..., 116, 117, 117]]], dtype=uint8), array([[[135, 136, 136, ..., 162, 164, 164], [135, 136, 136, ..., 163, 164, 159], [136, 136, 136, ..., 163, 164, 162], ..., [162, 163, 161, ..., 134, 136, 136], [161, 163, 159, ..., 136, 134, 136], [161, 164, 163, ..., 134, 136, 136]], [[136, 134, 136, ..., 164, 163, 163], [136, 134, 136, ..., 164, 163, 164], [136, 133, 136, ..., 163, 162, 164], ..., [164, 162, 161, ..., 136, 134, 136], [163, 162, 163, ..., 136, 133, 136], [164, 161, 164, ..., 134, 136, 136]], [[136, 135, 136, ..., 164, 162, 164], [136, 134, 135, ..., 164, 161, 164], [136, 133, 136, ..., 163, 162, 164], ..., [164, 159, 161, ..., 133, 136, 136], [164, 161, 159, ..., 134, 134, 135], [164, 161, 161, ..., 134, 136, 135]], ..., [[108, 108, 80, ..., 136, 131, 133], [106, 108, 70, ..., 136, 133, 134], [105, 109, 106, ..., 136, 134, 136], ..., [113, 109, 112, ..., 65, 83, 92], [113, 108, 112, ..., 79, 97, 101], [113, 112, 113, ..., 87, 101, 101]], [[106, 108, 101, ..., 135, 131, 136], [104, 111, 95, ..., 136, 136, 136], [105, 112, 105, ..., 134, 136, 136], ..., [112, 112, 111, ..., 80, 98, 101], [112, 113, 112, ..., 98, 105, 107], [109, 112, 113, ..., 102, 108, 106]], [[108, 108, 107, ..., 131, 129, 134], [105, 108, 105, ..., 133, 134, 135], [105, 109, 101, ..., 134, 135, 136], ..., [108, 111, 113, ..., 84, 104, 106], [109, 111, 113, ..., 99, 107, 109], [111, 111, 114, ..., 101, 109, 107]]], dtype=uint8), array([[[166, 127, 168, ..., 173, 171, 170], [168, 149, 168, ..., 172, 172, 170], [166, 128, 169, ..., 171, 172, 170], ..., [172, 1, 173, ..., 150, 150, 145], [171, 1, 173, ..., 150, 149, 145], [171, 1, 173, ..., 145, 147, 143]], [[165, 165, 164, ..., 171, 173, 172], [164, 164, 165, ..., 171, 173, 173], [165, 164, 166, ..., 170, 173, 173], ..., [170, 169, 169, ..., 148, 150, 148], [169, 169, 169, ..., 150, 150, 145], [169, 171, 171, ..., 145, 147, 145]], [[165, 164, 164, ..., 170, 173, 169], [164, 164, 164, ..., 171, 173, 171], [164, 165, 164, ..., 170, 173, 171], ..., [169, 170, 171, ..., 148, 150, 147], [169, 169, 170, ..., 147, 150, 145], [169, 170, 171, ..., 147, 147, 143]], ..., [[143, 142, 145, ..., 46, 63, 62], [144, 147, 133, ..., 60, 66, 63], [136, 143, 125, ..., 65, 66, 63], ..., [ 36, 49, 36, ..., 127, 129, 129], [ 39, 104, 38, ..., 128, 130, 130], [ 42, 119, 38, ..., 128, 127, 128]], [[150, 150, 136, ..., 50, 62, 62], [151, 150, 138, ..., 59, 63, 63], [148, 149, 130, ..., 62, 66, 62], ..., [ 34, 74, 35, ..., 129, 127, 127], [ 38, 116, 36, ..., 127, 131, 127], [ 49, 121, 39, ..., 127, 127, 127]], [[154, 152, 143, ..., 55, 59, 60], [156, 151, 142, ..., 57, 64, 62], [154, 155, 137, ..., 62, 66, 62], ..., [ 37, 112, 37, ..., 129, 129, 127], [ 44, 117, 38, ..., 127, 131, 129], [ 80, 122, 55, ..., 126, 128, 127]]], dtype=uint8), array([[[173, 169, 170, ..., 173, 173, 172], [171, 170, 171, ..., 172, 171, 172], [171, 170, 171, ..., 173, 171, 171], ..., [173, 172, 173, ..., 161, 161, 168], [173, 173, 173, ..., 164, 164, 168], [174, 172, 173, ..., 166, 164, 171]], [[171, 171, 169, ..., 173, 173, 173], [169, 171, 169, ..., 173, 172, 173], [171, 170, 169, ..., 173, 172, 170], ..., [173, 173, 172, ..., 151, 158, 169], [172, 173, 172, ..., 166, 165, 170], [173, 173, 171, ..., 168, 164, 173]], [[169, 169, 170, ..., 169, 172, 173], [170, 169, 171, ..., 172, 171, 173], [168, 170, 171, ..., 172, 170, 173], ..., [172, 173, 173, ..., 142, 155, 169], [172, 173, 174, ..., 161, 163, 169], [172, 173, 173, ..., 165, 163, 172]], ..., [[156, 159, 159, ..., 117, 117, 117], [156, 161, 158, ..., 119, 117, 117], [155, 159, 157, ..., 117, 117, 117], ..., [117, 117, 117, ..., 159, 163, 168], [117, 120, 116, ..., 164, 164, 162], [117, 117, 117, ..., 159, 164, 150]], [[158, 159, 158, ..., 117, 119, 117], [156, 159, 158, ..., 117, 117, 117], [155, 158, 158, ..., 117, 117, 117], ..., [117, 117, 117, ..., 159, 162, 165], [117, 117, 117, ..., 161, 164, 157], [117, 119, 117, ..., 161, 165, 147]], [[157, 158, 159, ..., 117, 116, 116], [155, 158, 159, ..., 117, 114, 117], [155, 156, 159, ..., 120, 116, 117], ..., [117, 117, 117, ..., 164, 159, 163], [117, 117, 119, ..., 163, 164, 155], [117, 117, 117, ..., 163, 164, 144]]], dtype=uint8), array([[[148, 149, 147, ..., 149, 1, 150], [145, 149, 145, ..., 151, 1, 152], [147, 149, 145, ..., 152, 1, 150], ..., [172, 173, 173, ..., 173, 1, 173], [173, 173, 173, ..., 173, 1, 172], [173, 172, 172, ..., 173, 1, 173]], [[144, 145, 145, ..., 150, 150, 151], [145, 145, 145, ..., 155, 150, 151], [147, 145, 144, ..., 154, 154, 151], ..., [170, 171, 173, ..., 173, 173, 174], [170, 170, 173, ..., 176, 173, 173], [171, 169, 169, ..., 173, 173, 173]], [[144, 147, 145, ..., 151, 150, 150], [145, 145, 145, ..., 150, 150, 149], [147, 147, 147, ..., 151, 155, 149], ..., [169, 170, 170, ..., 172, 173, 173], [172, 170, 170, ..., 173, 174, 171], [172, 169, 169, ..., 173, 173, 173]], ..., [[133, 131, 133, ..., 137, 136, 135], [134, 134, 133, ..., 141, 136, 135], [133, 133, 133, ..., 137, 136, 135], ..., [120, 77, 105, ..., 150, 147, 143], [122, 98, 120, ..., 150, 149, 138], [120, 107, 121, ..., 149, 147, 136]], [[133, 133, 131, ..., 136, 136, 136], [133, 131, 131, ..., 137, 136, 135], [131, 133, 133, ..., 136, 136, 134], ..., [117, 95, 108, ..., 148, 147, 140], [119, 114, 119, ..., 150, 150, 135], [120, 119, 119, ..., 145, 145, 134]], [[131, 134, 131, ..., 136, 133, 134], [131, 135, 130, ..., 136, 135, 133], [133, 136, 131, ..., 136, 134, 131], ..., [119, 114, 117, ..., 145, 145, 136], [117, 117, 119, ..., 145, 147, 133], [117, 120, 117, ..., 145, 145, 131]]], dtype=uint8), array([[[172, 156, 166, ..., 174, 173, 176], [174, 173, 173, ..., 177, 173, 176], [173, 172, 173, ..., 176, 174, 176], ..., [169, 171, 170, ..., 168, 169, 169], [169, 169, 169, ..., 172, 168, 171], [170, 169, 169, ..., 170, 170, 171]], [[174, 155, 166, ..., 173, 174, 174], [174, 172, 174, ..., 173, 176, 176], [174, 174, 174, ..., 173, 176, 173], ..., [171, 171, 172, ..., 168, 171, 169], [172, 171, 172, ..., 169, 169, 169], [170, 171, 169, ..., 169, 170, 169]], [[173, 154, 166, ..., 172, 177, 173], [174, 173, 174, ..., 174, 174, 173], [173, 174, 173, ..., 173, 177, 173], ..., [170, 168, 169, ..., 166, 169, 165], [170, 169, 170, ..., 169, 169, 169], [169, 168, 169, ..., 168, 171, 169]], ..., [[137, 136, 136, ..., 136, 136, 135], [136, 136, 136, ..., 136, 136, 135], [138, 136, 136, ..., 137, 136, 135], ..., [ 48, 50, 73, ..., 127, 131, 120], [ 58, 66, 116, ..., 127, 127, 127], [ 77, 85, 122, ..., 126, 127, 123]], [[136, 136, 136, ..., 135, 133, 135], [136, 136, 136, ..., 136, 136, 134], [136, 135, 136, ..., 136, 134, 134], ..., [ 62, 57, 115, ..., 127, 130, 116], [ 98, 97, 123, ..., 125, 126, 127], [115, 116, 126, ..., 123, 126, 122]], [[135, 135, 135, ..., 135, 134, 135], [136, 136, 135, ..., 135, 135, 134], [136, 136, 135, ..., 135, 135, 134], ..., [ 72, 85, 117, ..., 123, 122, 108], [107, 120, 122, ..., 123, 125, 122], [117, 123, 122, ..., 122, 126, 128]]], dtype=uint8), array([[[172, 173, 169, ..., 169, 170, 168], [170, 172, 168, ..., 169, 171, 171], [169, 170, 169, ..., 173, 169, 170], ..., [171, 170, 169, ..., 165, 163, 164], [170, 170, 169, ..., 164, 164, 163], [169, 169, 168, ..., 168, 164, 166]], [[171, 170, 171, ..., 168, 172, 170], [169, 169, 169, ..., 169, 172, 171], [169, 170, 169, ..., 169, 171, 170], ..., [169, 168, 171, ..., 164, 165, 166], [169, 169, 169, ..., 163, 164, 165], [169, 168, 169, ..., 164, 164, 169]], [[172, 171, 171, ..., 172, 168, 168], [172, 172, 168, ..., 172, 169, 170], [171, 171, 168, ..., 172, 169, 169], ..., [169, 169, 169, ..., 164, 166, 166], [170, 169, 170, ..., 164, 166, 165], [170, 169, 169, ..., 165, 164, 165]], ..., [[141, 135, 131, ..., 126, 126, 127], [159, 143, 131, ..., 126, 126, 127], [159, 162, 136, ..., 127, 126, 127], ..., [115, 114, 113, ..., 112, 112, 113], [115, 115, 113, ..., 108, 111, 112], [117, 117, 113, ..., 112, 108, 113]], [[137, 134, 133, ..., 126, 127, 125], [159, 138, 131, ..., 127, 127, 127], [161, 159, 136, ..., 128, 127, 127], ..., [115, 114, 113, ..., 113, 113, 113], [113, 115, 113, ..., 113, 113, 108], [114, 115, 113, ..., 112, 113, 112]], [[136, 133, 131, ..., 126, 127, 125], [159, 137, 131, ..., 128, 127, 127], [163, 158, 136, ..., 127, 127, 127], ..., [114, 113, 115, ..., 113, 112, 112], [116, 114, 113, ..., 112, 113, 108], [115, 113, 113, ..., 113, 113, 109]]], dtype=uint8), array([[[164, 164, 164, ..., 39, 49, 15], [164, 164, 164, ..., 43, 49, 28], [165, 164, 164, ..., 43, 45, 17], ..., [166, 164, 154, ..., 115, 113, 15], [165, 164, 163, ..., 144, 141, 20], [165, 164, 164, ..., 161, 159, 20]], [[164, 163, 163, ..., 38, 46, 15], [164, 163, 163, ..., 39, 45, 28], [164, 164, 161, ..., 43, 45, 24], ..., [164, 159, 151, ..., 115, 111, 20], [165, 163, 159, ..., 144, 141, 21], [164, 161, 162, ..., 159, 158, 18]], [[164, 164, 164, ..., 44, 43, 11], [164, 164, 165, ..., 42, 43, 18], [164, 163, 164, ..., 41, 42, 29], ..., [164, 157, 128, ..., 130, 108, 18], [164, 164, 149, ..., 151, 137, 20], [166, 164, 155, ..., 158, 157, 21]], ..., [[154, 150, 150, ..., 155, 155, 93], [152, 150, 150, ..., 155, 154, 93], [151, 149, 149, ..., 155, 152, 93], ..., [115, 117, 117, ..., 113, 113, 101], [115, 116, 117, ..., 115, 113, 91], [116, 115, 114, ..., 113, 113, 89]], [[152, 149, 149, ..., 155, 152, 93], [150, 149, 150, ..., 155, 152, 92], [150, 147, 147, ..., 154, 152, 92], ..., [112, 116, 116, ..., 113, 114, 104], [113, 116, 115, ..., 113, 113, 85], [115, 113, 114, ..., 113, 114, 81]], [[150, 148, 149, ..., 155, 155, 93], [150, 147, 149, ..., 155, 154, 92], [150, 145, 147, ..., 155, 154, 94], ..., [113, 114, 116, ..., 113, 115, 106], [113, 116, 116, ..., 114, 114, 87], [114, 115, 115, ..., 113, 115, 86]]], dtype=uint8)]], dtype=object), 'dirnames': array([[array(['1a'], dtype='<U2'), array(['1b'], dtype='<U2'), array(['1c'], dtype='<U2'), array(['1d'], dtype='<U2'), array(['1e'], dtype='<U2'), array(['1f'], dtype='<U2'), array(['1g'], dtype='<U2'), array(['1h'], dtype='<U2'), array(['1i'], dtype='<U2'), array(['1j'], dtype='<U2'), array(['1k'], dtype='<U2'), array(['1l'], dtype='<U2'), array(['1m'], dtype='<U2'), array(['1n'], dtype='<U2'), array(['1o'], dtype='<U2'), array(['1p'], dtype='<U2'), array(['1q'], dtype='<U2'), array(['1r'], dtype='<U2'), array(['1s'], dtype='<U2'), array(['1t'], dtype='<U2')]], dtype=object)}
In [403]:
data, label = [], []
for idx, person in enumerate(face_data):
for i in range(person.shape[2]):
img = person[:,:,i]
img = img.reshape(10304)
data.append(img)
label.append(idx)
data, label = np.array(data), np.array(label) # LABEL: Range of label values: 0-19. When using the for-loop and print(f'Label: {label}'), each picture is tagged with the correct label per-person
print(f'data shape: {data.shape}, label shape: {label.shape}, unique people ids: {np.unique(label)}')
data shape: (575, 10304), label shape: (575,), unique people ids: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
In [404]:
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(10, 10))
for i in range(0,20):
fig.add_subplot(4,5,i+1)
plt.imshow(data[i].reshape(112,92), cmap='gray')
plt.xticks([])
plt.yticks([])
person_id_label = f'Person ID: {label[i]}'
plt.title(person_id_label).set_color('green')
In [405]:
##
In [406]:
x_train, x_test, y_train, y_test = train_test_split(data, label, test_size=0.2, random_state=random_state, stratify=label)
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.25, random_state=random_state, stratify=y_train)
print(f'x_train: {x_train.shape}\ny_train: {y_train.shape}\nx_test: {x_test.shape}\ny_test: {y_test.shape}\nx_val: {x_val.shape}\ny_val: {y_val.shape}')
x_train: (345, 10304) y_train: (345,) x_test: (115, 10304) y_test: (115,) x_val: (115, 10304) y_val: (115,)
In [407]:
# Standard scaling
scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test)
x_val = scaler.transform(x_val)
y_train = torch.eye(20)[torch.LongTensor(y_train)]
y_test = torch.eye(20)[torch.LongTensor(y_test)]
y_val = torch.eye(20)[torch.LongTensor(y_val)]
print(f'x_train: {x_train.shape}\ny_train: {y_train.shape}\nx_test: {x_test.shape}\ny_test: {y_test.shape}\nx_val: {x_val.shape}\ny_val: {y_val.shape}')
x_train: (345, 10304) y_train: torch.Size([345, 20]) x_test: (115, 10304) y_test: torch.Size([115, 20]) x_val: (115, 10304) y_val: torch.Size([115, 20])
In [408]:
pca = PCA(n_components=10, random_state=random_state)
x_train_pca = pca.fit_transform(x_train)
print(f'Explained variance: {sum(pca.explained_variance_ratio_)}')
x_test_pca = pca.transform(x_test)
x_val_pca = pca.transform(x_val)
Explained variance: 0.6492373520222806
In [409]:
print(f'x_train_pca: {x_train_pca.shape}\nx_test_pca: {x_test_pca.shape}\nx_val_pca: {x_val_pca.shape}')
x_train_pca: (345, 10) x_test_pca: (115, 10) x_val_pca: (115, 10)
In [410]:
face_recovered = pca.inverse_transform(x_train_pca) # Face images in the stratified training set, 0 to 344
plt.gray()
plt.xticks([])
plt.yticks([])
plt.title('Sample Image after PCA').set_color('green')
plt.imshow(np.array(face_recovered[111]).reshape(112,92))
Out[410]:
<matplotlib.image.AxesImage at 0x277c53853a0>
In [411]:
kmeans = KMeans(n_clusters=20, random_state=random_state)
x_train_kmeans = kmeans.fit_transform(x_train_pca)
x_test_kmeans = kmeans.transform(x_test_pca)
x_val_kmeans = kmeans.transform(x_val_pca)
x_train_tensor = torch.FloatTensor(x_train_kmeans)
x_test_tensor = torch.FloatTensor(x_test_kmeans)
x_val_tensor = torch.FloatTensor(x_val_kmeans)
D:\anaconda3\Lib\site-packages\sklearn\cluster\_kmeans.py:1429: UserWarning: KMeans is known to have a memory leak on Windows with MKL, when there are less chunks than available threads. You can avoid it by setting the environment variable OMP_NUM_THREADS=2. warnings.warn(
In [426]:
x_train_kmeans
Out[426]:
array([[ 85.34897988, 75.84730645, 76.09119472, ..., 97.77866986, 83.10348576, 51.85576079], [ 77.5241804 , 131.31719569, 105.1927116 , ..., 108.86916624, 93.96202626, 121.07083419], [ 98.44347613, 104.5782617 , 84.55857067, ..., 98.99388958, 75.12079416, 88.7045111 ], ..., [118.08961643, 35.20255626, 94.85801187, ..., 77.82242491, 120.16538824, 94.55763113], [105.803014 , 92.04468724, 91.22768778, ..., 105.86139586, 83.12872025, 41.19215493], [ 91.05638838, 99.60399459, 76.91126016, ..., 99.80034206, 87.99432228, 91.87189211]])
In [412]:
print(f'x_train_kmeans: {x_train_kmeans.shape}\nx_test_kmeans: {x_test_kmeans.shape}\nx_val_kmeans: {x_val_kmeans.shape}')
x_train_kmeans: (345, 20) x_test_kmeans: (115, 20) x_val_kmeans: (115, 20)
In [413]:
print(f'x_train_kmeans: {x_train_tensor.shape}\nx_test_kmeans: {x_train_tensor.shape}\nx_val_kmeans: {x_train_tensor.shape}')
x_train_kmeans: torch.Size([345, 20]) x_test_kmeans: torch.Size([345, 20]) x_val_kmeans: torch.Size([345, 20])
In [414]:
print(f'x_train_reshaped: {x_train_reshaped.shape}\ny_train_reshaped: {y_train_reshaped.shape}\nx_test_reshaped: {x_test_reshaped.shape}\ny_test_reshaped: {y_test_reshaped.shape}\nx_val_reshaped: {x_val_reshaped.shape}\ny_val_reshaped: {y_val_reshaped.shape}')
x_train_reshaped: (345, 1, 20) y_train_reshaped: torch.Size([345, 1, 20]) x_test_reshaped: (115, 1, 20) y_test_reshaped: torch.Size([115, 1, 20]) x_val_reshaped: (115, 1, 20) y_val_reshaped: torch.Size([115, 1, 20])
In [415]:
class FaceRecognitionModel(nn.Module):
def __init__(self, input_size, num_classes, units, dropout=False):
super(FaceRecognitionModel, self).__init__()
self.layers = nn.Sequential(
nn.Linear(input_size, 34),
nn.Linear(34, units),
nn.ReLU(),
nn.Dropout(0.2) if dropout else nn.Identity(),
nn.Linear(units, units),
nn.ReLU(),
nn.Dropout(0.2) if dropout else nn.Identity(),
nn.Linear(units, num_classes),
nn.Softmax(dim=1)
)
def forward(self, x):
x = x.view(x.size(0), -1) # Flatten the input
return self.layers(x)
In [416]:
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=100):
best_val_accuracy = 0
best_model = None
for epoch in range(epochs):
model.train()
train_loss = 0
train_correct = 0
total_train = 0
for inputs, labels in train_loader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = torch.max(outputs, 1)
_, true_labels = torch.max(labels, 1)
total_train += labels.size(0)
train_correct += (predicted == true_labels).sum().item()
# Validation
model.eval()
val_loss = 0
val_correct = 0
total_val = 0
with torch.no_grad():
for inputs, labels in val_loader:
outputs = model(inputs)
loss = criterion(outputs, labels)
val_loss += loss.item()
_, predicted = torch.max(outputs, 1)
_, true_labels = torch.max(labels, 1)
total_val += labels.size(0)
val_correct += (predicted == true_labels).sum().item()
train_accuracy = 100 * train_correct / total_train
val_accuracy = 100 * val_correct / total_val
if val_accuracy > best_val_accuracy:
best_val_accuracy = val_accuracy
best_model = model.state_dict()
if epoch % 10 == 0:
print(f'Epoch {epoch}: Train Loss {train_loss/len(train_loader):.4f}, '
f'Train Acc {train_accuracy:.2f}%, '
f'Val Loss {val_loss/len(val_loader):.4f}, '
f'Val Acc {val_accuracy:.2f}%')
return best_model
In [417]:
def visualize_predictions(model, x_data, y_data, x_original, title):
model.eval()
plt.figure(figsize=(15, 10))
# Predict for first 20 samples
predictions = []
with torch.no_grad():
for i in range(20):
inputs = x_data[i].unsqueeze(0)
output = model(inputs)
pred = torch.argmax(output).item()
true_label = torch.argmax(y_data[i]).item()
predictions.append(pred)
plt.subplot(4, 5, i+1)
# Reconstruct and show original image
img = x_original[i].reshape(112, 92)
plt.imshow(img, cmap='gray')
plt.title(f'Pred: {pred}, Actual: {true_lab el}')
plt.axis('off')
plt.suptitle(title)
plt.tight_layout()
plt.show()
return predictions
In [418]:
print(f'x_train_kmeans: {x_train_kmeans.shape}\nx_test_kmeans: {x_test_kmeans.shape}\nx_val_kmeans: {x_val_kmeans.shape}')
x_train_kmeans: (345, 20) x_test_kmeans: (115, 20) x_val_kmeans: (115, 20)
In [419]:
print(x_train.shape[1])
10304
In [420]:
print(x_train_tensor.shape)
torch.Size([345, 20])
In [421]:
# Create datasets and loaders
train_dataset = TensorDataset(x_train_tensor, y_train)
val_dataset = TensorDataset(x_val_tensor, y_val)
test_dataset = TensorDataset(x_test_tensor, y_test)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)
# Hyperparameters
model_params = {
'input_size': x_train_tensor.shape[1],
'num_classes': 20,
'units': 128,
'dropout': True
}
# Initialize model
model = FaceRecognitionModel(**model_params)
# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Train the model
best_model_state = train_model(model, train_loader, val_loader, criterion, optimizer)
# Load best model
model.load_state_dict(best_model_state)
# Evaluate on test set
model.eval()
test_correct = 0
total_test = 0
with torch.no_grad():
for inputs, labels in test_loader:
outputs = model(inputs)
_, predicted = torch.max(outputs, 1)
_, true_labels = torch.max(labels, 1)
total_test += labels.size(0)
test_correct += (predicted == true_labels).sum().item()
test_accuracy = 100 * test_correct / total_test
print(f'Test Accuracy: {test_accuracy:.2f}%')
Epoch 0: Train Loss 0.0802, Train Acc 6.96%, Val Loss 0.0650, Val Acc 11.30% Epoch 10: Train Loss 0.0399, Train Acc 32.75%, Val Loss 0.0371, Val Acc 37.39% Epoch 20: Train Loss 0.0255, Train Acc 62.61%, Val Loss 0.0229, Val Acc 66.96% Epoch 30: Train Loss 0.0168, Train Acc 76.52%, Val Loss 0.0128, Val Acc 83.48% Epoch 40: Train Loss 0.0115, Train Acc 84.06%, Val Loss 0.0080, Val Acc 90.43% Epoch 50: Train Loss 0.0102, Train Acc 84.93%, Val Loss 0.0075, Val Acc 90.43% Epoch 60: Train Loss 0.0065, Train Acc 91.30%, Val Loss 0.0041, Val Acc 94.78% Epoch 70: Train Loss 0.0060, Train Acc 92.75%, Val Loss 0.0037, Val Acc 94.78% Epoch 80: Train Loss 0.0055, Train Acc 91.59%, Val Loss 0.0043, Val Acc 93.91% Epoch 90: Train Loss 0.0040, Train Acc 95.07%, Val Loss 0.0022, Val Acc 97.39% Test Accuracy: 94.78%
In [422]:
# Visualize predictions
print("Visualizing Training Set Predictions")
train_predictions = visualize_predictions(model, x_train_tensor, y_train, x_train_orig, 'Training Set Predictions')
Visualizing Training Set Predictions
In [423]:
print("Visualizing Test Set Predictions")
test_predictions = visualize_predictions(model, x_test_tensor, y_test, x_test_orig, 'Test Set Predictions')
Visualizing Test Set Predictions
In [424]:
print("Visualizing Validation Set Predictions")
val_predictions = visualize_predictions(model, x_val_tensor, y_val, x_val_orig, 'Validation Set Predictions')
Visualizing Validation Set Predictions
In [425]:
torch.save(model.state_dict(), 'face_recognition_model.pth')
In [ ]: