Problem: The research project aimed to explore the correlation between features extracted from a Vision Transformer (ViT) model trained to classify rock images and the human perceptual features provided in the MDS_120 dataset. Specifically, the focus was on understanding how the weights of the 8-unit linear layer in the ViT model relate to human cognitive perception.
Context: The project involved classifying rocks into three categories using a neural network model. The subsequent analysis aimed to bridge the gap between machine-based rock classification and human perceptual understanding.
This research project builds on a series of earlier projects on modeling human cognition with neural networks. Its results are used to study how human perceptual behavior is reflected in neural networks.
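The core analysis described above can be sketched as follows. This is a minimal illustration, not the project's actual code: the names `model_features` (per-image activations of the ViT's 8-unit linear layer for the MDS_120 images) and `human_mds` (the human-derived MDS coordinates) are hypothetical, random placeholders stand in for the real data, and the assumption that both matrices share the same number of dimensions is made purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Placeholders for the real data (hypothetical shapes: 120 images x 8 dimensions)
model_features = rng.standard_normal((120, 8))  # ViT 8-unit linear-layer activations
human_mds = rng.standard_normal((120, 8))       # human perceptual (MDS_120) coordinates

# Pearson correlation between every model feature and every human dimension:
# corrcoef on the stacked variables yields a 16x16 matrix; the off-diagonal
# 8x8 block holds the model-vs-human cross-correlations.
corr = np.corrcoef(model_features.T, human_mds.T)[:8, 8:]

# For each model feature, the human dimension it tracks most strongly
best_match = corr.argmax(axis=1)
```

With real activations and MDS coordinates in place of the random matrices, inspecting `corr` row by row shows which human perceptual dimensions, if any, each learned feature aligns with.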
Activities:
Role: One of the researchers in the project, responsible for the design, implementation, and analysis.
Approach:
A Python code snippet for splitting the datasets and preprocessing them is shown below:
# Split the training data into training and validation sets
splits = rock_ds.train_test_split(test_size=0.333)
train_ds = splits['train']
val_ds = splits['test']

from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Derive the resize/crop sizes from the image processor's configuration
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = image_processor.size.get("longest_edge")

train_transforms = Compose(
    [
        RandomResizedCrop(crop_size),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)
val_transforms = Compose(
    [
        Resize(size),
        CenterCrop(crop_size),
        ToTensor(),
        normalize,
    ]
)

def preprocess_train(example_batch):
    """Apply train_transforms across a batch."""
    example_batch["pixel_values"] = [
        train_transforms(image.convert("RGB")) for image in example_batch["img"]
    ]
    return example_batch

def preprocess_val(example_batch):
    """Apply val_transforms across a batch."""
    example_batch["pixel_values"] = [
        val_transforms(image.convert("RGB")) for image in example_batch["img"]
    ]
    return example_batch

train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)
test_ds.set_transform(preprocess_val)  # test_ds: held-out test split, defined earlier
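Once the transforms are attached, the datasets can be batched for the ViT with a small collate function, as in the standard Hugging Face image-classification recipe. This is a sketch under assumptions: the label column is assumed to be named "label", and the example below uses dummy tensors in place of real dataset rows.

```python
import torch

def collate_fn(examples):
    # Each example carries "pixel_values" from preprocess_train / preprocess_val
    pixel_values = torch.stack([ex["pixel_values"] for ex in examples])
    # "label" is the assumed name of the class column (one of the three rock categories)
    labels = torch.tensor([ex["label"] for ex in examples])
    return {"pixel_values": pixel_values, "labels": labels}

# Dummy batch of four 3x224x224 images to show the resulting shapes
examples = [{"pixel_values": torch.zeros(3, 224, 224), "label": i % 3} for i in range(4)]
batch = collate_fn(examples)
```

A `torch.utils.data.DataLoader` built with `collate_fn=collate_fn` would then yield batches ready to feed into the model's forward pass.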