Image and Text Analysis using Multi-modal Embeddings

Last Update: 22 Mar 2024

Description: Workshop materials for Image and Text Topic Modelling using Multi-modal Embeddings. Designed to run on Google Colab. The full version is available at https://github.com/justinchuntingho/ImageTextAnalysisWorkshop

The Workflow

Setting Up

!pip install bertopic
import zipfile
import urllib
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from bertopic import BERTopic
import base64
from io import BytesIO
from IPython.display import HTML

Getting Data

# Download the workshop data archive, unpack it into the working directory,
# and load the metadata table into `df`.
# `import urllib` alone does not guarantee that the `urllib.request` submodule
# is loaded, so import it explicitly before calling `urlretrieve`.
import urllib.request

zip_path, _ = urllib.request.urlretrieve("https://github.com/justinchuntingho/ImageTextAnalysisWorkshop/raw/main/data.zip")
with zipfile.ZipFile(zip_path, "r") as f:
    f.extractall("./")
df = pd.read_csv('data/data.csv')
df

Model Training

True MultiModal

def get_concat_h_multi_resize(im_list, resample=Image.BICUBIC):
    """Scale images to a shared height and join them left to right.

    Every image in ``im_list`` is resized (width scaled proportionally) to the
    height of the shortest image, then pasted side by side onto a new RGB
    canvas.

    Args:
        im_list: Sequence of images exposing ``width``, ``height`` and
            ``resize``.
        resample: Resampling filter forwarded to ``resize``.

    Returns:
        The new RGB image containing all resized inputs in a single row.
    """
    target_height = min(img.height for img in im_list)

    scaled = []
    for img in im_list:
        new_width = int(img.width * target_height / img.height)
        scaled.append(img.resize((new_width, target_height), resample=resample))

    canvas_width = sum(img.width for img in scaled)
    canvas = Image.new('RGB', (canvas_width, target_height))

    # Paste each image immediately to the right of the previous one.
    x_offset = 0
    for img in scaled:
        canvas.paste(img, (x_offset, 0))
        x_offset += img.width
    return canvas

def get_concat_v_multi_resize(im_list, resample=Image.BICUBIC):
    """Scale images to a shared width and stack them top to bottom.

    Every image in ``im_list`` is resized (height scaled proportionally) to
    the width of the narrowest image, then pasted one below the other onto a
    new RGB canvas.

    Args:
        im_list: Sequence of images exposing ``width``, ``height`` and
            ``resize``.
        resample: Resampling filter forwarded to ``resize``.

    Returns:
        The new RGB image containing all resized inputs in a single column.
    """
    target_width = min(img.width for img in im_list)

    scaled = []
    for img in im_list:
        new_height = int(img.height * target_width / img.width)
        scaled.append(img.resize((target_width, new_height), resample=resample))

    canvas_height = sum(img.height for img in scaled)
    canvas = Image.new('RGB', (target_width, canvas_height))

    # Paste each image immediately below the previous one.
    y_offset = 0
    for img in scaled:
        canvas.paste(img, (0, y_offset))
        y_offset += img.height
    return canvas

def get_concat_tile_resize(im_list_2d, resample=Image.BICUBIC):
    """Assemble a 2-D list of images into one tiled image.

    Each inner list of ``im_list_2d`` is first joined into a horizontal strip
    with ``get_concat_h_multi_resize``; the strips are then stacked with
    ``get_concat_v_multi_resize``. ``resample`` is forwarded to both helpers.
    """
    strips = []
    for row_images in im_list_2d:
        strips.append(get_concat_h_multi_resize(row_images, resample=resample))
    return get_concat_v_multi_resize(strips, resample=resample)

def get_top_imgs(topic, n_images=9, n_cols=3):
    """Build a tiled collage of the images scoring highest for a topic.

    Reads the module-level ``probs_df`` (indexed by ``topic`` as a column
    label) and ``df`` (its ``image_path`` column), so both must exist before
    this function is called.

    Args:
        topic: Column label in ``probs_df`` identifying the topic.
        n_images: Maximum number of top-scoring images to include. The default
            of 9 reproduces the original 3x3 collage.
        n_cols: Number of images per collage row (default 3).

    Returns:
        The collage image produced by ``get_concat_tile_resize``.
    """
    top_idx = probs_df[topic].nlargest(n_images).index
    # A loop replaces nine copy-pasted ``Image.open`` calls and also works
    # when fewer than ``n_images`` rows are available.
    imgs = [Image.open(df['image_path'][idx]) for idx in top_idx]
    # Chunk the ranked images into rows of ``n_cols``; the last row may be
    # shorter when the image count is not a multiple of ``n_cols``.
    rows = [imgs[start:start + n_cols] for start in range(0, len(imgs), n_cols)]
    return get_concat_tile_resize(rows)

def image_base64(im):
    """Return ``im``, resized to 600x600 and saved as JPEG, as a base64 string."""
    buffer = BytesIO()
    try:
        resized = im.resize((600, 600))
        resized.save(buffer, 'jpeg')
        encoded = base64.b64encode(buffer.getvalue())
    finally:
        buffer.close()
    return encoded.decode()

def image_formatter(im):
    """Render ``im`` as an HTML ``<img>`` tag with an inline base64 JPEG data URI."""
    encoded = image_base64(im)
    return f'<img src="data:image/jpeg;base64,{encoded}">'

def truncate_sentence(sentence, tokenizer, max_tokens=77):
    """Shorten ``sentence`` until it encodes to at most ``max_tokens`` tokens.

    While the encoded sentence is too long, the first token is dropped and the
    following ``max_tokens - 2`` tokens are decoded back into text; the result
    is re-encoded and checked again, because decoding and re-encoding is not
    guaranteed to give the same token count.

    Args:
        sentence: Text to truncate.
        tokenizer: Object exposing ``encode(text) -> sequence of tokens`` and
            ``decode(tokens) -> text``. Presumably ``encode`` adds one leading
            and one trailing special token, which is why the slice skips the
            first token and reserves two slots -- confirm for other tokenizers.
        max_tokens: Maximum allowed length of ``tokenizer.encode(...)``.
            Defaults to 77, the limit that was previously hard-coded.

    Returns:
        ``sentence`` unchanged if it already fits, otherwise the truncated text.

    Raises:
        ValueError: If truncation stops making progress (the decoded text no
            longer changes while still exceeding ``max_tokens``).
    """
    current = sentence
    tokens = tokenizer.encode(current)
    # Iterate instead of recursing so repeated truncation cannot hit the
    # interpreter's recursion limit.
    while len(tokens) > max_tokens:
        shortened = tokenizer.decode(tokens[1:max_tokens - 1])
        if shortened == current:
            # Re-encoding identical text would loop forever.
            raise ValueError(
                f"cannot truncate sentence to {max_tokens} tokens"
            )
        current = shortened
        tokens = tokenizer.encode(current)
    return current
from bertopic.backend import MultiModalBackend
from transformers import CLIPTokenizer
# CLIP embedding backend, plus a CLIP tokenizer used only to truncate the text.
model = MultiModalBackend('clip-ViT-B-32', batch_size=32)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Truncate every text so it fits truncate_sentence's token limit before embedding.
docs = [truncate_sentence(x,tokenizer) for x in df['text'].tolist()]
images = df['image_path'].tolist()
docs[0:6]
images[0:6]
# Embed both images and documents, then average them
doc_image_embeddings = model.embed(docs, images)
topic_model = BERTopic(calculate_probabilities=True,
                       n_gram_range=(1,2),
                       min_topic_size=5, # Setting this based on the smallest category in GS
                       verbose=True)

topics, probs = topic_model.fit_transform(docs, doc_image_embeddings)
topic_model.get_topic_info()
df['topic'] = topics
probs_df = pd.DataFrame(probs)
# Extract dataframe
topic_info = topic_model.get_topic_info().drop("Representative_Docs", axis=1).drop("Name", axis=1)
# Exclude the outlier topic by its id (-1) rather than by row position:
# when no document is an outlier, row 0 is a real topic and must be kept.
# `.copy()` lets the 'Visual' column be added without a chained-assignment warning.
topic_info = topic_info[topic_info.Topic != -1].copy()
topic_info['Visual'] = [get_top_imgs(x) for x in topic_info.Topic]
# Render once and reuse: rendering re-encodes every collage image.
multimodal_html = topic_info.to_html(formatters={'Visual': image_formatter}, escape=False,index=False)
HTML(multimodal_html)
# Explicit UTF-8 so non-ASCII topic words are written identically on every platform.
with open('multimodal.html', 'w', encoding='utf-8') as fo:
    fo.write(multimodal_html)

Image Only

def image_base64(im):
    """Return a base64-encoded JPEG of ``im``.

    Args:
        im: Either an image object exposing ``save(buffer, format)`` (such as
            a PIL image) or a ``str`` file path to an image.

    Returns:
        The JPEG bytes encoded as a base64 ``str``.
    """
    if isinstance(im, str):
        # Open the path with PIL directly: no `get_thumbnail` helper is
        # defined in this notebook, so calling it raised NameError.
        # Converting to RGB loads the pixel data before the file is closed
        # and avoids JPEG save errors for modes such as RGBA or P.
        with Image.open(im) as opened:
            im = opened.convert('RGB')
    with BytesIO() as buffer:
        im.save(buffer, 'jpeg')
        return base64.b64encode(buffer.getvalue()).decode()
from bertopic.representation import KeyBERTInspired, VisualRepresentation
from bertopic.backend import MultiModalBackend

# Image embedding model
embedding_model = MultiModalBackend('clip-ViT-B-32', batch_size=32)

# Image to text representation model
representation_model = {
    "Visual_Aspect": VisualRepresentation(image_to_text_model="nlpconnect/vit-gpt2-image-captioning")
}
# Train our model with images only
topic_model = BERTopic(embedding_model=embedding_model,
                       representation_model=representation_model,
                       min_topic_size=5,
                       calculate_probabilities=True)
topics, probs = topic_model.fit_transform(documents=None, images=df.image_path.to_list())
df['topic'] = topics
probs_df = pd.DataFrame(probs)
# Extract dataframe
topic_info = topic_model.get_topic_info().drop("Representative_Docs", axis=1).drop("Name", axis=1)
# Exclude the outlier topic by its id (-1) rather than by row position:
# when no image is an outlier, row 0 is a real topic and must be kept.
topic_info = topic_info[topic_info.Topic != -1].copy()
# Render once and reuse: rendering re-encodes every 'Visual_Aspect' image.
img_html = topic_info.to_html(formatters={'Visual_Aspect': image_formatter}, escape=False, index=False)
HTML(img_html)
# Explicit UTF-8 so non-ASCII topic words are written identically on every platform.
with open('img.html', 'w', encoding='utf-8') as fo:
    fo.write(img_html)

Convert to Text

from transformers import pipeline
# Caption every image, then model topics on caption + original text.
image_to_text = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
df['generated_text'] = [image_to_text(x)[0]['generated_text'] for x in df.image_path]
# Join caption and text with a space so the last caption word and the first
# text word are not fused into one token; pass a plain list of strings.
docs = (df.generated_text + " " + df.text).tolist()
# Train our model with text only
topic_model = BERTopic(embedding_model=SentenceTransformer("all-MiniLM-L6-v2"),
                       n_gram_range=(1,2),
                       min_topic_size=5,
                       calculate_probabilities=True)
topics, probs = topic_model.fit_transform(documents=docs)
df['topic'] = topics
probs_df = pd.DataFrame(probs)
# Extract dataframe
topic_info = topic_model.get_topic_info().drop("Representative_Docs", axis=1).drop("Name", axis=1)
# Exclude the outlier topic by its id (-1) rather than by row position:
# when no document is an outlier, row 0 is a real topic and must be kept.
topic_info = topic_info[topic_info.Topic != -1].copy()
# Render once and reuse the same HTML for display and for the file.
text_html = topic_info.to_html(escape=False,index=False)
HTML(text_html)
# Explicit UTF-8 so non-ASCII topic words are written identically on every platform.
with open('text.html', 'w', encoding='utf-8') as fo:
    fo.write(text_html)