Fine-Tuning a Vision Transformer for Rock Classification

Situation

Problem: The research project aimed to explore the correlation between features extracted from a Vision Transformer (ViT) model trained to classify rock images and the human perceptual features provided in the MDS_120 dataset. Specifically, the focus was on understanding how the weights of the 8-feature linear layer in the ViT model relate to human cognitive perceptions.

Context: The project involved classifying rocks into three categories using a neural network model. The subsequent analysis aimed to bridge the gap between machine-based rock classification and human perceptual understanding.

This research project builds on a series of earlier projects on modeling human cognition with neural networks. Its results are used to study human perceptual behavior in neural networks.

Task

Activities:

  1. Trained a Vision Transformer model on a dataset of rock images, comprising 360 training images and 120 testing images.
  2. Split the training data to reserve 120 images for validation.
  3. Fine-tuned the model using the rock data.
  4. Extracted the weights from the linear layer with 8 features in the pre-final layers of the ViT model.
  5. Utilized Procrustes analysis to compare these weights with the human perceptual features in the MDS_120 dataset.
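Step 5 can be sketched with SciPy's Procrustes implementation. The shapes here are assumptions: a (120, 8) matrix of model features for the 120 rocks and a (120, 8) matrix of human MDS coordinates; the function name `compare_to_mds` is hypothetical.

```python
import numpy as np
from scipy.spatial import procrustes

def compare_to_mds(model_feats, mds_coords):
    """Procrustes-align model features to human MDS coordinates.

    model_feats: (n_stimuli, 8) features from the 8-unit layer (assumed shape)
    mds_coords:  (n_stimuli, 8) human perceptual coordinates from MDS_120
    Returns the standardized Procrustes disparity in [0, 1],
    where 0 indicates a perfect match after translation,
    rotation, and scaling.
    """
    _, _, disparity = procrustes(mds_coords, model_feats)
    return disparity
```

Because SciPy standardizes both matrices before alignment, the disparity is scale-invariant and comparable across models.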

Role: One of the researchers in the project, responsible for the design, implementation, and analysis.

Action

Approach:

  1. Data Preparation: Organized rock image data into training, validation, and testing sets.

A Python code snippet for preprocessing the datasets after splitting is shown below:

# split the training set into training (240) and validation (120) images
splits = rock_ds.train_test_split(test_size=0.333)
train_ds = splits['train']
val_ds = splits['test']

from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

# image_processor is the processor loaded with the pretrained ViT checkpoint;
# its normalization statistics and size config drive the transforms below
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = image_processor.size.get("longest_edge")

train_transforms = Compose(
        [
            RandomResizedCrop(crop_size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )

val_transforms = Compose(
        [
            Resize(size),
            CenterCrop(crop_size),
            ToTensor(),
            normalize,
        ]
    )

def preprocess_train(example_batch):
    """Apply train_transforms across a batch."""
    example_batch["pixel_values"] = [
        train_transforms(image.convert("RGB")) for image in example_batch["img"]
    ]
    return example_batch

def preprocess_val(example_batch):
    """Apply val_transforms across a batch."""
    example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["img"]]
    return example_batch

# apply the transforms on the fly; test_ds is the held-out 120-image test split
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)
test_ds.set_transform(preprocess_val)
  2. Model Training: Implemented and trained a Vision Transformer model using the training data.
  3. Fine-Tuning: Adjusted the model on the rock data for improved performance.
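
The classification setup described above, a ViT backbone feeding an 8-feature linear layer before the final 3-way classifier, can be sketched as a small PyTorch head. This is a minimal sketch, not the project's actual implementation: the class name `RockClassifierHead` and the default `hidden_dim=768` (ViT-Base pooled output) are assumptions.

```python
import torch
import torch.nn as nn

class RockClassifierHead(nn.Module):
    """Hypothetical head: ViT pooled output -> 8-feature bottleneck -> 3 classes.

    The 8-unit bottleneck is the layer whose weights are later compared
    with the human MDS_120 features via Procrustes analysis.
    hidden_dim=768 matches ViT-Base but is an assumption here.
    """
    def __init__(self, hidden_dim=768, n_features=8, n_classes=3):
        super().__init__()
        self.bottleneck = nn.Linear(hidden_dim, n_features)
        self.classifier = nn.Linear(n_features, n_classes)

    def forward(self, pooled):
        # feats is the 8-dimensional representation analyzed against MDS_120
        feats = self.bottleneck(pooled)
        logits = self.classifier(feats)
        return logits, feats
```

Returning both the logits and the 8-dimensional features makes it straightforward to collect per-image features for the later Procrustes comparison without a separate forward hook.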