<?xml version="1.0" encoding="utf-8" standalone="yes"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:content="http://purl.org/rss/1.0/modules/content/">
  <channel>
    <title>Machine Learning on Max Woolf&#39;s Blog</title>
    <link>https://minimaxir.com/tag/machine-learning/</link>
    <description>Recent content in Machine Learning on Max Woolf&#39;s Blog</description>
    <image>
      <title>Max Woolf&#39;s Blog</title>
      <url>https://minimaxir.com/android-chrome-512x512.png</url>
      <link>https://minimaxir.com/android-chrome-512x512.png</link>
    </image>
    <generator>Hugo</generator>
    <language>en</language>
    <copyright>Copyright Max Woolf © 2026</copyright>
    <lastBuildDate>Mon, 30 Jun 2025 10:00:00 -0700</lastBuildDate>
    <atom:link href="https://minimaxir.com/tag/machine-learning/index.xml" rel="self" type="application/rss+xml" />
    <item>
      <title>Predicting Average IMDb Movie Ratings Using Text Embeddings of Movie Metadata</title>
      <link>https://minimaxir.com/2025/06/movie-embeddings/</link>
      <pubDate>Mon, 30 Jun 2025 10:00:00 -0700</pubDate>
      <guid>https://minimaxir.com/2025/06/movie-embeddings/</guid>
      <description>Don&amp;rsquo;t try this in your data science interviews.</description>
      <content:encoded><![CDATA[<p>Months ago, I saw a post titled &ldquo;<a href="https://www.reddit.com/r/datascience/comments/1eykil7/rejected_from_ds_role_with_no_feedback/">Rejected from DS Role with no feedback</a>&rdquo; on Reddit&rsquo;s <a href="https://www.reddit.com/r/datascience/">Data Science subreddit</a>, in which a prospective job candidate for a data science position provided a <a href="https://colab.research.google.com/drive/1Ud2tXW2IAw_dXA5DONvNpPmmlL1foSwK">Colab Notebook</a> documenting their submission for a take-home assignment and asking for feedback as to why they were rejected. Per the Reddit user, the assignment was:</p>
<blockquote>
<p>Use the publicly available <a href="https://developer.imdb.com/non-commercial-datasets/">IMDB Datasets</a> to build a model that predicts a movie&rsquo;s average rating. Please document your approach and present your results in the notebook. Make sure your code is well-organized so that we can follow your modeling process.</p>
</blockquote>
<p><a href="https://www.imdb.com/">IMDb</a>, the Internet Movie Database owned by Amazon, allows users to rate movies on a scale from 1 to 10, wherein the average rating is then displayed prominently on the movie&rsquo;s page:</p>
<figure>

    <img loading="lazy" srcset="/2025/06/movie-embeddings/shawshank_hu_fe8025c2c6a0fa89.webp 320w,/2025/06/movie-embeddings/shawshank_hu_f0b2bc74865ccb73.webp 768w,/2025/06/movie-embeddings/shawshank_hu_8f544060412f7f54.webp 1024w,/2025/06/movie-embeddings/shawshank.webp 1082w" src="shawshank.webp"
         alt="The Shawshank Redemption is currently the highest-rated movie on IMDb with an average rating of 9.3 derived from 3.1 million user votes."/> <figcaption>
            <p><a href="https://www.imdb.com/title/tt0111161/?ref_=sr_t_1">The Shawshank Redemption</a> is currently the <a href="https://www.imdb.com/search/title/?groups=top_100&amp;sort=user_rating,desc">highest-rated movie on IMDb</a> with an average rating of 9.3 derived from 3.1 million user votes.</p>
        </figcaption>
</figure>

<p>In their notebook, the Redditor identifies a few intuitive features for such a model, including the year in which the movie was released, the genre(s) of the movie, and the actors/directors of the movie. However, the model they built is a <a href="https://www.tensorflow.org/">TensorFlow</a> and <a href="https://keras.io/">Keras</a>-based neural network, with all the bells and whistles such as <a href="https://en.wikipedia.org/wiki/Batch_normalization">batch normalization</a> and <a href="https://en.wikipedia.org/wiki/Dilution_%28neural_networks%29">dropout</a>. The immediate response by other data scientists on /r/datascience was, at its most polite, &ldquo;why did you use a neural network when it&rsquo;s a <a href="https://en.wikipedia.org/wiki/Black_box">black box</a> that you can&rsquo;t explain?&rdquo;</p>
<p>Reading those replies made me nostalgic. Way back in 2017, before my first job as a data scientist, neural networks built with frameworks such as TensorFlow and Keras were all the rage for their ability to &ldquo;<a href="https://en.wikipedia.org/wiki/Universal_approximation_theorem">solve any problem</a>&rdquo;, but using them was often seen as lazy and unskilled compared to traditional statistical modeling such as ordinary least squares linear regression or even gradient boosted trees. It&rsquo;s funny to see that this perception of neural networks in the data science community hasn&rsquo;t changed since then, although nowadays their black box nature can be an acceptable business tradeoff if the prediction results are higher quality and interpretability is not required.</p>
<p>Looking back at the assignment description, the objective is only to &ldquo;predict a movie&rsquo;s average rating.&rdquo; For data science interview take-homes, this is unusual: those assignments typically include an extra instruction along the lines of &ldquo;explain your model and what decisions stakeholders should make as a result of it&rdquo;, which is a strong hint that you need to use an explainable model like linear regression to obtain feature coefficients, or at least a middle ground like gradient boosted trees, whose <a href="https://stats.stackexchange.com/questions/332960/what-is-variable-importance">variable importance</a> quantifies each feature&rsquo;s relative contribution to the model. <sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup> In the absence of that particular constraint, it&rsquo;s arguable that anything goes, including neural networks.</p>
<p>The quality of neural networks has improved significantly since 2017, even more so given the massive rise of LLMs. Why not just feed an LLM all the raw metadata for a movie, encode it into a text embedding, and build a statistical model on top of that? Would a neural network do better than a traditional statistical model in that instance? Let&rsquo;s find out!</p>
<h2 id="about-imdb-data">About IMDb Data</h2>
<p>The <a href="https://developer.imdb.com/non-commercial-datasets/">IMDb Non-Commercial Datasets</a> are famous datasets that have been around for nearly a decade <sup id="fnref:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup> but are still updated daily. Back in 2018, as a budding data scientist, I performed a <a href="https://minimaxir.com/2018/07/imdb-data-analysis/">fun exploratory data analysis</a> using these datasets, although the results weren&rsquo;t too surprising.</p>
<figure>

    <img loading="lazy" srcset="/2025/06/movie-embeddings/imdb-4_hu_fdf90cbdd2dd2c7e.webp 320w,/2025/06/movie-embeddings/imdb-4_hu_1c45abe215427c09.webp 768w,/2025/06/movie-embeddings/imdb-4_hu_62d0feb034e8b054.webp 1024w,/2025/06/movie-embeddings/imdb-4.png 1200w" src="imdb-4.png"
         alt="The average rating for a movie is around 6 and tends to skew higher: a common trend in internet rating systems."/> <figcaption>
            <p>The average rating for a movie is around 6 and tends to skew higher: a common trend in internet rating systems.</p>
        </figcaption>
</figure>

<p>But in truth, these datasets are a terrible choice for companies to use in a take-home assignment. Although the datasets are released under a non-commercial license, IMDb doesn&rsquo;t want to give too much information to its competitors, which results in a severely limited number of features that could be used to build a good predictive model. Here are the common movie-performance-related features present in the <code>title.basics.tsv.gz</code> file:</p>
<ul>
<li><strong>tconst</strong>: unique identifier of the title</li>
<li><strong>titleType</strong>: the type/format of the title (e.g. <code>movie</code>, <code>tvMovie</code>, <code>short</code>, <code>tvSeries</code>, etc.)</li>
<li><strong>primaryTitle</strong>: the more popular title / the title used by the filmmakers on promotional materials at the point of release</li>
<li><strong>isAdult</strong>: 0: non-adult title; 1: adult title</li>
<li><strong>startYear</strong>: the release year of the title</li>
<li><strong>runtimeMinutes</strong>: primary runtime of the title, in minutes</li>
<li><strong>genres</strong>: includes up to three genres associated with the title</li>
</ul>
<p>This is a sensible schema for describing a movie, although it lacks some important information that would be very useful for determining movie quality, such as production company, summary blurbs, granular genres/tags, and plot/setting — all of which are available on the IMDb movie page itself and presumably accessible through the <a href="https://developer.imdb.com/documentation/api-documentation/?ref_=/documentation/_PAGE_BODY">paid API</a>. Of note, since the assignment explicitly asks for a <em>movie</em>&rsquo;s average rating, we need to filter the data to only <code>movie</code> and <code>tvMovie</code> entries, which the original submission failed to do.</p>
<p>The ratings data in <code>title.ratings.tsv.gz</code> is what you&rsquo;d expect:</p>
<ul>
<li><strong>tconst</strong>: unique identifier of the title (which can therefore be mapped to movie metadata using a JOIN)</li>
<li><strong>averageRating</strong>: average of all the individual user ratings</li>
<li><strong>numVotes</strong>: number of votes the title has received</li>
</ul>
<p>In order to ensure that the average ratings for modeling are indeed stable and indicative of user sentiment, I will only analyze movies that have <em>at least 30 user votes</em>: as of May 10th, 2025, that&rsquo;s about 242k movies total. Additionally, I will not use <code>numVotes</code> as a model feature, since that metric reflects extrinsic movie popularity more than the movie itself.</p>
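<p>As a rough sketch of this loading-and-filtering step in Polars (a minimal sketch, assuming the TSVs have been downloaded locally; <code>\N</code> is IMDb&rsquo;s null marker):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3">import polars as pl

# IMDb TSVs use "\N" for nulls and contain unescaped quote characters,
# so quoting is disabled while parsing.
df_basics = pl.read_csv(
    "title.basics.tsv.gz", separator="\t", null_values="\\N", quote_char=None
)
df_ratings = pl.read_csv(
    "title.ratings.tsv.gz", separator="\t", null_values="\\N", quote_char=None
)

# Keep only movies/TV movies with a stable number of user votes.
df_movies = (
    df_basics.filter(pl.col("titleType").is_in(["movie", "tvMovie"]))
    .join(df_ratings, on="tconst")
    .filter(pl.col("numVotes") &gt;= 30)
)
</code></pre></div>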
<p>The last major dataset is <code>title.principals.tsv.gz</code>, which has very helpful information on metadata such as the roles people play in the production of a movie:</p>
<ul>
<li><strong>tconst</strong>: unique identifier of the title (which can be mapped to movie data using a JOIN)</li>
<li><strong>nconst</strong>: unique identifier of the principal (this is mapped to <code>name.basics.tsv.gz</code> to get the principal&rsquo;s <code>primaryName</code>, but nothing else useful)</li>
<li><strong>category</strong>: the role the principal served in the title, such as <code>actor</code>, <code>actress</code>, <code>writer</code>, <code>producer</code>, etc.</li>
<li><strong>ordering</strong>: the ordering of the principals within the title, which correlates to the order the principals appear on IMDb&rsquo;s movie cast pages.</li>
</ul>
<p>Additionally, because the datasets are so popular, this is far from the first time someone has built an IMDb ratings predictor, and prior attempts are easy to Google.</p>
<figure>

    <img loading="lazy" srcset="/2025/06/movie-embeddings/google_hu_b09e979836a71049.webp 320w,/2025/06/movie-embeddings/google_hu_c652438955f310d8.webp 768w,/2025/06/movie-embeddings/google.webp 1000w" src="google.webp"/> 
</figure>

<p>Instead of using the official IMDb datasets, these analyses are based on the smaller <a href="https://www.kaggle.com/datasets/carolzhangdc/imdb-5000-movie-dataset/data">IMDB 5000 Movie Dataset</a> hosted on Kaggle, which adds metadata such as movie rating, budget, and further actor metadata that makes building a model much easier (albeit &ldquo;number of likes on the lead actor&rsquo;s Facebook page&rdquo; is <em>very</em> extrinsic to movie quality). Using the official datasets with much less metadata means building the models on hard mode, and they will likely have lower predictive performance.</p>
<p>Although IMDb data is very popular and very well documented, that doesn&rsquo;t mean it&rsquo;s easy to work with.</p>
<h2 id="the-initial-assignment-and-feature-engineering">The Initial Assignment and &ldquo;Feature Engineering&rdquo;</h2>
<p>Data science take-home assignments are typically 1/2 <a href="https://en.wikipedia.org/wiki/Exploratory_data_analysis">exploratory data analysis</a> for identifying impactful dataset features, and 1/2 building, iterating on, and explaining the model. For real-world datasets, these are all very difficult problems with many possible solutions, and the goal from the employer&rsquo;s perspective is more to see <em>how</em> these problems are solved than to judge the actual quantitative results.</p>
<p>The initial Reddit post decided to engineer some expected features using <a href="https://pandas.pydata.org/">pandas</a>, such as <code>is_sequel</code>, derived by checking whether a non-<code>1</code> number is present at the end of a movie title, and <a href="https://en.wikipedia.org/wiki/One-hot">one-hot encoding</a> each distinct <code>genre</code> of a movie. These are fine for an initial approach, although sequel titles can be idiosyncratic, which suggests that a more <a href="https://www.ibm.com/think/topics/natural-language-processing">NLP</a>-driven approach to identifying sequels and other related media may be useful.</p>
<p>The main trick with this assignment is how to handle the principals. The common data science approach would be a sparse binary encoding of the actors/directors/writers, e.g. a vector where actors present in the movie are <code>1</code> and every other actor is <code>0</code>; there are many ways to encode this data performantly, such as scikit-learn&rsquo;s <a href="https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MultiLabelBinarizer.html">MultiLabelBinarizer</a>. The problem with this approach is that the number of unique actors is <em>very</em> large — a <a href="https://docs.honeycomb.io/get-started/basics/observability/concepts/high-cardinality/">high-cardinality</a> feature with more unique actors than data points themselves — which leads to <a href="https://en.wikipedia.org/wiki/Curse_of_dimensionality">curse of dimensionality</a> issues, and workarounds such as encoding only the top <em>N</em> actors render the feature uninformative, since even a generous <em>N</em> will fail to capture the majority of actors.</p>
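<p>To make the dimensionality problem concrete, here&rsquo;s a toy sketch of that sparse binary encoding with scikit-learn&rsquo;s MultiLabelBinarizer (the actor lists are illustrative):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3">from sklearn.preprocessing import MultiLabelBinarizer

# Each row is the list of actors for one movie.
movie_actors = [
    ["Mark Hamill", "Harrison Ford", "Carrie Fisher"],
    ["Harrison Ford", "Karen Allen"],
    ["Anthony Daniels", "Kenny Baker"],
]

mlb = MultiLabelBinarizer()
X = mlb.fit_transform(movie_actors)
print(X.shape)  # (3, 6): one binary column per unique actor

# On the real dataset, the column count (~624k unique actors) would
# exceed the number of movies (~242k), hence the dimensionality issues.
</code></pre></div>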
<figure>

    <img loading="lazy" srcset="/2025/06/movie-embeddings/actor_cum_dist_hu_6b3839329e455b7d.webp 320w,/2025/06/movie-embeddings/actor_cum_dist_hu_b3985aca3321429a.webp 768w,/2025/06/movie-embeddings/actor_cum_dist_hu_27acda9c003abad5.webp 1024w,/2025/06/movie-embeddings/actor_cum_dist.png 1500w" src="actor_cum_dist.png"
         alt="There are actually 624k unique actors in this dataset (Jupyter Notebook), the chart just becomes hard to read at that point."/> <figcaption>
            <p>There are actually 624k unique actors in this dataset (<a href="https://github.com/minimaxir/imdb-embeddings/blob/main/actor_agg.ipynb">Jupyter Notebook</a>), the chart just becomes hard to read at that point.</p>
        </figcaption>
</figure>

<p>Additionally, most statistical modeling approaches cannot account for the <code>ordering</code> of actors since they treat each feature as independent, and because the billing order of actors generally correlates with their importance in the movie, discarding that ordering omits information relevant to the problem.</p>
<p>These constraints gave me an idea: why not use an LLM to encode <em>all</em> movie data, and build a model using the downstream embedding representation? LLMs have <a href="https://en.wikipedia.org/wiki/Attention_%28machine_learning%29">attention mechanisms</a>, which will not only respect the relative ordering of actors (to give higher predictive priority to higher-billed actors, along with actor cooccurrences), but also identify patterns within movie name texts (to identify sequels and related media semantically).</p>
<p>I started by aggregating and denormalizing all the data locally (<a href="https://github.com/minimaxir/imdb-embeddings/blob/main/imdb_polars_etl_test.ipynb">Jupyter Notebook</a>). Each of the IMDb datasets is hundreds of megabytes and hundreds of thousands of rows at minimum: not quite <a href="https://en.wikipedia.org/wiki/Big_data">big data</a>, but enough to be more cognizant of tooling, especially since computationally-intensive JOINs are required. Therefore, I used the <a href="https://pola.rs/">Polars</a> library in Python, which not only loads data super fast, but is also one of the <a href="https://duckdblabs.github.io/db-benchmark/">fastest libraries at performing JOINs</a> and other aggregation tasks. Polars&rsquo;s syntax also allows for some cool tricks: for example, I want to spread out and aggregate the principals (4.1 million rows after prefiltering) for each movie into directors, writers, producers, actors, and all other principals as nested lists, while simultaneously having them sorted by <code>ordering</code> as noted above. This is much easier to do in Polars than in any other data processing library I&rsquo;ve used, and on millions of rows, it takes <em>less than a second</em>:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3"><span class="line"><span class="cl"><span class="n">df_principals_agg</span> <span class="o">=</span> <span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="n">df_principals</span><span class="o">.</span><span class="n">sort</span><span class="p">([</span><span class="s2">&#34;tconst&#34;</span><span class="p">,</span> <span class="s2">&#34;ordering&#34;</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">    <span class="o">.</span><span class="n">group_by</span><span class="p">(</span><span class="s2">&#34;tconst&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="o">.</span><span class="n">agg</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="n">director_names</span><span class="o">=</span><span class="n">pl</span><span class="o">.</span><span class="n">col</span><span class="p">(</span><span class="s2">&#34;primaryName&#34;</span><span class="p">)</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">pl</span><span class="o">.</span><span class="n">col</span><span class="p">(</span><span class="s2">&#34;category&#34;</span><span class="p">)</span> <span class="o">==</span> <span class="s2">&#34;director&#34;</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">        <span class="n">writer_names</span><span class="o">=</span><span class="n">pl</span><span class="o">.</span><span class="n">col</span><span class="p">(</span><span class="s2">&#34;primaryName&#34;</span><span class="p">)</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">pl</span><span class="o">.</span><span class="n">col</span><span class="p">(</span><span class="s2">&#34;category&#34;</span><span class="p">)</span> <span class="o">==</span> <span class="s2">&#34;writer&#34;</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">        <span class="n">producer_names</span><span class="o">=</span><span class="n">pl</span><span class="o">.</span><span class="n">col</span><span class="p">(</span><span class="s2">&#34;primaryName&#34;</span><span class="p">)</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">pl</span><span class="o">.</span><span class="n">col</span><span class="p">(</span><span class="s2">&#34;category&#34;</span><span class="p">)</span> <span class="o">==</span> <span class="s2">&#34;producer&#34;</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">        <span class="n">actor_names</span><span class="o">=</span><span class="n">pl</span><span class="o">.</span><span class="n">col</span><span class="p">(</span><span class="s2">&#34;primaryName&#34;</span><span class="p">)</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="n">pl</span><span class="o">.</span><span class="n">col</span><span class="p">(</span><span class="s2">&#34;category&#34;</span><span class="p">)</span><span class="o">.</span><span class="n">is_in</span><span class="p">([</span><span class="s2">&#34;actor&#34;</span><span class="p">,</span> <span class="s2">&#34;actress&#34;</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">        <span class="p">),</span>
</span></span><span class="line"><span class="cl">        <span class="n">principal_names</span><span class="o">=</span><span class="n">pl</span><span class="o">.</span><span class="n">col</span><span class="p">(</span><span class="s2">&#34;primaryName&#34;</span><span class="p">)</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="o">~</span><span class="n">pl</span><span class="o">.</span><span class="n">col</span><span class="p">(</span><span class="s2">&#34;category&#34;</span><span class="p">)</span><span class="o">.</span><span class="n">is_in</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">                <span class="p">[</span><span class="s2">&#34;director&#34;</span><span class="p">,</span> <span class="s2">&#34;writer&#34;</span><span class="p">,</span> <span class="s2">&#34;producer&#34;</span><span class="p">,</span> <span class="s2">&#34;actor&#34;</span><span class="p">,</span> <span class="s2">&#34;actress&#34;</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">            <span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="p">),</span>
</span></span><span class="line"><span class="cl">        <span class="n">principal_roles</span><span class="o">=</span><span class="n">pl</span><span class="o">.</span><span class="n">col</span><span class="p">(</span><span class="s2">&#34;category&#34;</span><span class="p">)</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="o">~</span><span class="n">pl</span><span class="o">.</span><span class="n">col</span><span class="p">(</span><span class="s2">&#34;category&#34;</span><span class="p">)</span><span class="o">.</span><span class="n">is_in</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">                <span class="p">[</span><span class="s2">&#34;director&#34;</span><span class="p">,</span> <span class="s2">&#34;writer&#34;</span><span class="p">,</span> <span class="s2">&#34;producer&#34;</span><span class="p">,</span> <span class="s2">&#34;actor&#34;</span><span class="p">,</span> <span class="s2">&#34;actress&#34;</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">            <span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="p">),</span>
</span></span><span class="line"><span class="cl">    <span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">)</span>
</span></span></code></pre></div><p>After some cleanup and field renaming, here&rsquo;s an example JSON document for <a href="https://www.imdb.com/title/tt0076759/">Star Wars: Episode IV - A New Hope</a>:</p>
<!-- prettier-ignore-start -->
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;title&#34;</span><span class="p">:</span> <span class="s2">&#34;Star Wars: Episode IV - A New Hope&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;genres&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Action&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Adventure&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Fantasy&#34;</span>
</span></span><span class="line"><span class="cl">  <span class="p">],</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;is_adult&#34;</span><span class="p">:</span> <span class="kc">false</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;release_year&#34;</span><span class="p">:</span> <span class="mi">1977</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;runtime_minutes&#34;</span><span class="p">:</span> <span class="mi">121</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;directors&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;George Lucas&#34;</span>
</span></span><span class="line"><span class="cl">  <span class="p">],</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;writers&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;George Lucas&#34;</span>
</span></span><span class="line"><span class="cl">  <span class="p">],</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;producers&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Gary Kurtz&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Rick McCallum&#34;</span>
</span></span><span class="line"><span class="cl">  <span class="p">],</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;actors&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Mark Hamill&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Harrison Ford&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Carrie Fisher&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Alec Guinness&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Peter Cushing&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Anthony Daniels&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Kenny Baker&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Peter Mayhew&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;David Prowse&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Phil Brown&#34;</span>
</span></span><span class="line"><span class="cl">  <span class="p">],</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;principals&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;John Williams&#34;</span><span class="p">:</span> <span class="s2">&#34;composer&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;Gilbert Taylor&#34;</span><span class="p">:</span> <span class="s2">&#34;cinematographer&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;Richard Chew&#34;</span><span class="p">:</span> <span class="s2">&#34;editor&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;T.M. Christopher&#34;</span><span class="p">:</span> <span class="s2">&#34;editor&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;Paul Hirsch&#34;</span><span class="p">:</span> <span class="s2">&#34;editor&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;Marcia Lucas&#34;</span><span class="p">:</span> <span class="s2">&#34;editor&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;Dianne Crittenden&#34;</span><span class="p">:</span> <span class="s2">&#34;casting_director&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;Irene Lamb&#34;</span><span class="p">:</span> <span class="s2">&#34;casting_director&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;Vic Ramos&#34;</span><span class="p">:</span> <span class="s2">&#34;casting_director&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;John Barry&#34;</span><span class="p">:</span> <span class="s2">&#34;production_designer&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></div><!-- prettier-ignore-end -->
<p>I was tempted to claim that I used zero feature engineering, but that wouldn&rsquo;t be accurate. The selection and ordering of the JSON fields here is itself feature engineering: for example, <code>actors</code> and <code>principals</code> are intentionally last in this JSON encoding because they can have wildly varying lengths while the prior fields are more consistent, which should make downstream encodings more comparable and consistent.</p>
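<p>The serialization itself can be as simple as <code>json.dumps</code> over a plain dict, since Python dicts preserve insertion order and the <code>indent</code> argument produces the nested indentation. A minimal sketch, assuming a row aggregated with the field names from the Polars snippet above (with <code>genres</code> already split into a list):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3">import json

def movie_to_doc(row: dict) -&gt; str:
    # Field order is deliberate: consistent scalar fields first, and
    # variable-length lists (actors, principals) last.
    doc = {
        "title": row["primaryTitle"],
        "genres": row["genres"],
        "is_adult": bool(row["isAdult"]),
        "release_year": row["startYear"],
        "runtime_minutes": row["runtimeMinutes"],
        "directors": row["director_names"],
        "writers": row["writer_names"],
        "producers": row["producer_names"],
        "actors": row["actor_names"],
        "principals": [
            {name: role}
            for name, role in zip(row["principal_names"], row["principal_roles"])
        ],
    }
    return json.dumps(doc, indent=2)
</code></pre></div>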
<p>Now, let&rsquo;s discuss how to convert these JSON representations of movies into embeddings.</p>
<h2 id="creating-and-visualizing-the-movie-embeddings">Creating And Visualizing the Movie Embeddings</h2>
<p>LLMs that are trained to output text embeddings are not much different from LLMs like <a href="https://chatgpt.com/">ChatGPT</a> that just predict the next token in a loop. Models such as BERT and GPT can generate &ldquo;embeddings&rdquo; out-of-the-box by skipping the prediction heads of the models and instead taking an encoded value from the last hidden state of the model (e.g. for BERT, the first positional vector of the hidden state, representing the <code>[CLS]</code> token). However, dedicated text embedding models are optimized through <a href="https://lilianweng.github.io/posts/2021-05-31-contrastive/">contrastive learning</a> to make the embedding of a given input text document distinctive. These embeddings can be used for many things, from finding similar inputs by measuring the similarity between their embeddings to, of course, building a statistical model on top of them.</p>
<p>Text embeddings that leverage LLMs are typically generated in batches on a GPU due to the increased amount of computation needed. Python libraries such as <a href="https://huggingface.co/">Hugging Face</a> <a href="https://huggingface.co/docs/transformers/en/index">transformers</a> and <a href="https://sbert.net/">sentence-transformers</a> can load these embedding models. For this experiment, I used the very new <a href="https://huggingface.co/Alibaba-NLP/gte-modernbert-base">Alibaba-NLP/gte-modernbert-base</a> text embedding model, finetuned from the <a href="https://huggingface.co/answerdotai/ModernBERT-base">ModernBERT model</a> specifically for the embedding use case, for two reasons: the ModernBERT architecture is <a href="https://huggingface.co/blog/modernbert">optimized for fast inference</a>, and the base ModernBERT model is trained to be more code-aware and should be able to understand JSON-nested input strings more robustly — that&rsquo;s also why I intentionally left in the indentation for nested JSON arrays, as it&rsquo;s semantically meaningful and <a href="https://huggingface.co/answerdotai/ModernBERT-base/blob/main/tokenizer_config.json">explicitly tokenized</a>. <sup id="fnref:3"><a href="#fn:3" class="footnote-ref" role="doc-noteref">3</a></sup></p>
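<p>As a point of reference, loading this model and its tokenizer with transformers (a minimal sketch, assuming a CUDA GPU is available) looks like:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3">import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer

model_id = "Alibaba-NLP/gte-modernbert-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id).to("cuda:0").eval()
</code></pre></div>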
<p>The code (<a href="https://github.com/minimaxir/imdb-embeddings/blob/main/generate_imdb_embeddings.ipynb">Jupyter Notebook</a>) — with extra considerations to avoid running out of memory on either the CPU or GPU <sup id="fnref:4"><a href="#fn:4" class="footnote-ref" role="doc-noteref">4</a></sup> — looks something like this:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3"><span class="line"><span class="cl"><span class="n">device</span> <span class="o">=</span> <span class="s2">&#34;cuda:0&#34;</span>
</span></span><span class="line"><span class="cl"><span class="n">dataloader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">docs</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                         <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                         <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                                         <span class="n">pin_memory_device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">dataset_embeddings</span> <span class="o">=</span> <span class="p">[]</span>
</span></span><span class="line"><span class="cl"><span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">dataloader</span><span class="p">,</span> <span class="n">smoothing</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">tokenized_batch</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="n">batch</span><span class="p">,</span> <span class="n">max_length</span><span class="o">=</span><span class="mi">8192</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">truncation</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">return_tensors</span><span class="o">=</span><span class="s2">&#34;pt&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
</span></span><span class="line"><span class="cl">        <span class="n">outputs</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="o">**</span><span class="n">tokenized_batch</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">embeddings</span> <span class="o">=</span> <span class="n">outputs</span><span class="o">.</span><span class="n">last_hidden_state</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">    <span class="n">dataset_embeddings</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">embeddings</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">dataset_embeddings</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">dataset_embeddings</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">dataset_embeddings</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">normalize</span><span class="p">(</span><span class="n">dataset_embeddings</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
</span></span></code></pre></div><figure>

    <img loading="lazy" srcset="/2025/06/movie-embeddings/featured_hu_be15fd7c96cd6da2.webp 320w,/2025/06/movie-embeddings/featured_hu_a1d4e8d783c0419.webp 768w,/2025/06/movie-embeddings/featured_hu_1aa1372a6affcdc5.webp 1024w,/2025/06/movie-embeddings/featured.webp 1318w" src="featured.webp"/> 
</figure>

<p>I used a Spot <a href="https://cloud.google.com/blog/products/compute/introducing-g2-vms-with-nvidia-l4-gpus">L4 GPU</a> on <a href="https://cloud.google.com/">Google Cloud Platform</a> at a pricing of $0.28/hour, and it took 21 minutes to encode all 242k movie embeddings: about $0.10 total, which is surprisingly efficient.</p>
<p>Each of these embeddings is a set of 768 numbers (768D). If the embeddings are unit normalized (the <code>F.normalize()</code> step), then calculating the dot product between embeddings returns the <a href="https://en.wikipedia.org/wiki/Cosine_similarity">cosine similarity</a> of those movies, which can then be used to identify the most similar movies. But &ldquo;similar&rdquo; is open-ended, as there are many dimensions along which a movie could be considered similar.</p>
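<p>Because the embeddings are unit vectors, a similarity query against the full dataset is a single matrix-vector product. A minimal sketch, assuming <code>titles</code> is a hypothetical list of movie titles aligned with the rows of <code>dataset_embeddings</code>:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3">def most_similar(query_idx: int, k: int = 10):
    # Dot products between unit vectors are cosine similarities.
    cossims = dataset_embeddings @ dataset_embeddings[query_idx]
    top = torch.topk(cossims, k)
    return [
        (titles[int(idx)], round(score.item(), 3))
        for score, idx in zip(top.values, top.indices)
    ]
</code></pre></div>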
<p>Let&rsquo;s try a few movie similarity test cases where I calculate the cosine similarity between one query movie and <em>all</em> movies, then sort by cosine similarity to find the most similar (<a href="https://github.com/minimaxir/imdb-embeddings/blob/main/movie_embeddings_similarity.ipynb">Jupyter Notebook</a>). How about Peter Jackson&rsquo;s <a href="https://www.imdb.com/title/tt0120737/">Lord of the Rings: The Fellowship of the Ring</a>? Ideally, not only would it surface the two other movies of the original trilogy, but also its prequel Hobbit trilogy.</p>
<table>
  <thead>
      <tr>
          <th>title</th>
          <th>cossim</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://www.imdb.com/title/tt0120737/">The Lord of the Rings: The Fellowship of the Ring (2001)</a></td>
          <td>1.0</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt0167261/">The Lord of the Rings: The Two Towers (2002)</a></td>
          <td>0.922</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt0167260/">The Lord of the Rings: The Return of the King (2003)</a></td>
          <td>0.92</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt10127200/">National Geographic: Beyond the Movie - The Lord of the Rings: The Fellowship of the Ring (2001)</a></td>
          <td>0.915</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt0301246/">A Passage to Middle-earth: The Making of &lsquo;Lord of the Rings&rsquo; (2001)</a></td>
          <td>0.915</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt0299105/">Quest for the Ring (2001)</a></td>
          <td>0.906</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt0077869/">The Lord of the Rings (1978)</a></td>
          <td>0.893</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt2310332/">The Hobbit: The Battle of the Five Armies (2014)</a></td>
          <td>0.891</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt1170358/">The Hobbit: The Desolation of Smaug (2013)</a></td>
          <td>0.883</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt0903624/">The Hobbit: An Unexpected Journey (2012)</a></td>
          <td>0.883</td>
      </tr>
  </tbody>
</table>
<p>Indeed, it worked and surfaced both trilogies! The other movies listed are about the original work, so having high similarity would be fair.</p>
<p>Compare these results to the &ldquo;<a href="https://help.imdb.com/article/imdb/discover-watch/what-is-the-more-like-this-section/GPE7SPGZREKKY7YN">More like this</a>&rdquo; section on the IMDb page for the movie itself, which has the two sequels to the original Lord of the Rings and two other suggestions that I am not entirely sure are actually related.</p>
<figure>

    <img loading="lazy" srcset="/2025/06/movie-embeddings/lotr_related_hu_7560f67c8d88cb97.webp 320w,/2025/06/movie-embeddings/lotr_related_hu_544b4f2cf95b01dd.webp 768w,/2025/06/movie-embeddings/lotr_related_hu_8c4f2099751f082.webp 1024w,/2025/06/movie-embeddings/lotr_related.webp 1354w" src="lotr_related.webp"/> 
</figure>

<p>What about more elaborate franchises, such as the <a href="https://en.wikipedia.org/wiki/Marvel_Cinematic_Universe">Marvel Cinematic Universe</a>? If you asked for movies similar to <a href="https://www.imdb.com/title/tt4154796/">Avengers: Endgame</a>, would other MCU films be the most similar?</p>
<table>
  <thead>
      <tr>
          <th>title</th>
          <th>cossim</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://www.imdb.com/title/tt4154796/">Avengers: Endgame (2019)</a></td>
          <td>1.0</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt4154756/">Avengers: Infinity War (2018)</a></td>
          <td>0.909</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt0848228/">The Avengers (2012)</a></td>
          <td>0.896</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt1217616/">Endgame (2009)</a></td>
          <td>0.894</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt4154664/">Captain Marvel (2019)</a></td>
          <td>0.89</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt2395427/">Avengers: Age of Ultron (2015)</a></td>
          <td>0.882</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt3498820/">Captain America: Civil War (2016)</a></td>
          <td>0.882</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt0292502/">Endgame (2001)</a></td>
          <td>0.881</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt0118661/">The Avengers (1998)</a></td>
          <td>0.877</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt1228705/">Iron Man 2 (2010)</a></td>
          <td>0.876</td>
      </tr>
  </tbody>
</table>
<p>The answer is yes, which isn&rsquo;t a surprise since those movies share many principals. However, there are instances of other movies named &ldquo;Endgame&rdquo; and &ldquo;The Avengers&rdquo; that are completely unrelated to Marvel, which implies that the similarities may be fixating on the names.</p>
<p>What about movies of a smaller franchise but a specific domain, such as Disney&rsquo;s <a href="https://www.imdb.com/title/tt2294629/">Frozen</a> that only has one sequel? Would it surface other 3D animated movies by <a href="https://en.wikipedia.org/wiki/Walt_Disney_Animation_Studios">Walt Disney Animation Studios</a>, or something else?</p>
<table>
  <thead>
      <tr>
          <th>title</th>
          <th>cossim</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://www.imdb.com/title/tt2294629/">Frozen (2013)</a></td>
          <td>1.0</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt4520988/">Frozen II (2019)</a></td>
          <td>0.93</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt1323045/">Frozen (2010)</a></td>
          <td>0.92</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt1611845/">Frozen (2010)</a> [a different one]</td>
          <td>0.917</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt0125279/">Frozen (1996)</a></td>
          <td>0.909</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt0376606/">Frozen (2005)</a></td>
          <td>0.9</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt2363439/">The Frozen (2012)</a></td>
          <td>0.898</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt4007494/">The Story of Frozen: Making a Disney Animated Classic (2014)</a></td>
          <td>0.894</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt1071798/">Frozen (2007)</a></td>
          <td>0.889</td>
      </tr>
      <tr>
          <td><a href="https://www.imdb.com/title/tt4150316/">Frozen in Time (2014)</a></td>
          <td>0.888</td>
      </tr>
  </tbody>
</table>
<p>&hellip;okay, it&rsquo;s definitely fixating on the name. Let&rsquo;s try a different approach to see if we can find more meaningful patterns in these embeddings.</p>
<p>In order to visualize the embeddings, we can project them to a lower dimensionality with a dimensionality reduction algorithm such as <a href="https://en.wikipedia.org/wiki/Principal_component_analysis">PCA</a> or <a href="https://umap-learn.readthedocs.io/en/latest/">UMAP</a>: UMAP is preferred as it can simultaneously reorganize the data into more meaningful clusters. UMAP&rsquo;s <a href="https://umap-learn.readthedocs.io/en/latest/how_umap_works.html">construction of a neighborhood graph</a>, in theory, can allow the reduction to refine the similarities by leveraging many possible connections and hopefully avoid fixating on the movie name. However, with this amount of input data and the relatively high initial 768D vector size, the computational cost of UMAP is a concern, as both factors cause the UMAP training time to grow dramatically. Fortunately, NVIDIA&rsquo;s <a href="https://github.com/rapidsai/cuml">cuML library</a> was recently <a href="https://github.com/rapidsai/cuml/releases/tag/v25.04.00">updated</a> so that you can now run UMAP on very large amounts of data on a GPU, with a very high number of epochs to ensure the reduction fully converges, so I did just that (<a href="https://github.com/minimaxir/imdb-embeddings/blob/main/imdb_embeddings_umap_to_2D.ipynb">Jupyter Notebook</a>).</p>
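<p>cuML&rsquo;s UMAP mirrors the umap-learn API, so the invocation (a sketch with illustrative hyperparameters) is only a few lines:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3">from cuml.manifold import UMAP

reducer = UMAP(
    n_components=2,
    n_neighbors=15,
    n_epochs=5000,  # a very high epoch count is cheap on a GPU
    min_dist=0.1,
)
coords_2d = reducer.fit_transform(dataset_embeddings.numpy())
</code></pre></div>
<p>What patterns can we find? Let&rsquo;s try plotting the reduced points, colored by their user rating.</p>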
<figure>

    <img loading="lazy" srcset="/2025/06/movie-embeddings/imdb_umap_rating_hu_4047e53667cc289a.webp 320w,/2025/06/movie-embeddings/imdb_umap_rating_hu_74d5c85f14c8950c.webp 768w,/2025/06/movie-embeddings/imdb_umap_rating_hu_2b6ccdbb5b4b9105.webp 1024w,/2025/06/movie-embeddings/imdb_umap_rating.webp 1200w" src="imdb_umap_rating.webp"/> 
</figure>

<p>So there are a few things going on here. Indeed, most of the points are high-rating green, as is evident in the source data. But the points and ratings aren&rsquo;t <em>random</em>, and there are trends. In the center giga cluster, there are soft subclusters of movies at high ratings and low ratings. Smaller discrete clusters did indeed form, but what is the deal with that extremely isolated cluster at the top? After investigation, that cluster only has movies released in 2008, meaning release year is another feature I should have considered when defining movie similarity.</p>
<p>As a sanity check, I faceted out the points by movie release year to better visualize where these clusters are forming:</p>
<figure>

    <img loading="lazy" srcset="/2025/06/movie-embeddings/imdb_umap_rating_year_hu_40c4d6844e346f92.webp 320w,/2025/06/movie-embeddings/imdb_umap_rating_year_hu_48d37fbda72976cc.webp 768w,/2025/06/movie-embeddings/imdb_umap_rating_year_hu_27485860dc95d177.webp 1024w,/2025/06/movie-embeddings/imdb_umap_rating_year.webp 1200w" src="imdb_umap_rating_year.webp"/> 
</figure>

<p>This shows that even the clustered movies have their ratings spread out, but I also unintentionally visualized how <a href="https://arize.com/docs/ax/machine-learning/computer-vision/how-to-cv/embedding-drift">embedding drift</a> manifests over time. 2024 is also a bizarrely-clustered year: I have no idea why those two years specifically are weird in movies.</p>
<p>The UMAP approach is more for fun, since it&rsquo;s better for the downstream model building to use the raw 768D vector and have it learn the features from that. At the least, there&rsquo;s <em>some</em> semantic signal preserved in these embeddings, which makes me optimistic that these embeddings alone can be used to train a viable movie rating predictor.</p>
<h2 id="predicting-average-imdb-movie-scores">Predicting Average IMDb Movie Scores</h2>
<p>So, we now have hundreds of thousands of 768D embeddings. How do we get them to predict movie ratings? What many don&rsquo;t know is that all methods of traditional statistical modeling also work with embeddings — assumptions such as feature independence are invalid, so the results aren&rsquo;t explainable, but you can still get a valid predictive model.</p>
<p>First, we will shuffle and split the dataset into a training set and a test set: for the test set, I chose 20,000 movies (roughly 10% of the data), which is more than enough for stable results. To decide on the best model, we will select the one that minimizes the <a href="https://en.wikipedia.org/wiki/Mean_squared_error">mean squared error</a> (MSE) on the test set, a standard approach for regression problems that predict a single numeric value.</p>
<p>Here are three approaches to using LLMs to solve non-next-token-prediction tasks.</p>
<h3 id="method-1-traditional-modeling-w-gpu-acceleration">Method #1: Traditional Modeling (w/ GPU Acceleration!)</h3>
<p>You can still fit a linear regression on top of the embeddings, even if the feature coefficients are completely useless, and it serves as a decent baseline (<a href="https://github.com/minimaxir/imdb-embeddings/blob/main/cuml_grid_search.ipynb">Jupyter Notebook</a>). The absolute laziest &ldquo;model&rdquo;, where we just use the mean of the training set for every prediction, results in a test MSE of <strong>1.637</strong>, while a simple linear regression on top of the 768D embeddings results in a more reasonable test MSE of <strong>1.187</strong>. We should be able to beat that handily with a more advanced model.</p>
<p>Data scientists familiar with scikit-learn know there&rsquo;s a rabbit hole of model options, but most of them are CPU-bound and single-threaded and would take a considerable amount of time on a dataset of this size. That&rsquo;s where cuML—the same library I used to create the UMAP projection—comes in, as cuML has <a href="https://docs.rapids.ai/api/cuml/stable/api/#regression-and-classification">GPU-native implementations</a> of the most popular scikit-learn models with a similar API. This notably includes <a href="https://en.wikipedia.org/wiki/Support_vector_machine">support vector machines</a>, which play especially nicely with embeddings. And because we have the extra compute, we can also perform a brute force hyperparameter <a href="https://www.dremio.com/wiki/grid-search/">grid search</a> to find the best parameters for fitting each model.</p>
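<p>A sketch of that search for the SVM case (the hyperparameter values are illustrative, and <code>X_train</code>/<code>X_test</code> are the embedding splits; cuML&rsquo;s <code>SVR</code> mirrors scikit-learn&rsquo;s API):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3">from itertools import product

import numpy as np
from cuml.svm import SVR

def mse(y_true, y_pred):
    return float(np.mean((np.asarray(y_true) - np.asarray(y_pred)) ** 2))

# Brute-force grid search over a few SVR hyperparameters.
best_params, best_mse = None, float("inf")
for C, epsilon in product([0.1, 1.0, 10.0], [0.01, 0.1, 0.5]):
    svr = SVR(kernel="rbf", C=C, epsilon=epsilon).fit(X_train, y_train)
    test_mse = mse(y_test, svr.predict(X_test))
    if test_mse &lt; best_mse:
        best_params, best_mse = (C, epsilon), test_mse
</code></pre></div>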
<p>Here are the MSE results on the test dataset for a few of these new model types, with the hyperparameter combination for each model type that best minimizes MSE:</p>
<figure>

    <img loading="lazy" srcset="/2025/06/movie-embeddings/model_comparison_base_hu_2e224af8e7736cd2.webp 320w,/2025/06/movie-embeddings/model_comparison_base_hu_ea8ec94f59331bc5.webp 768w,/2025/06/movie-embeddings/model_comparison_base_hu_536396210f6f6e7a.webp 1024w,/2025/06/movie-embeddings/model_comparison_base.png 1200w" src="model_comparison_base.png"/> 
</figure>

<p>The winner is the Support Vector Machine, with a test MSE of <strong>1.087</strong>! This is a good start for a simple approach that handily beats the linear regression baseline, and it also beats the model trained in the Redditor&rsquo;s original notebook, which had a test MSE of 1.096 <sup id="fnref:5"><a href="#fn:5" class="footnote-ref" role="doc-noteref">5</a></sup>. In all cases, the train set MSE was close to the test set MSE, which means the models did not overfit either.</p>
<h3 id="method-2-neural-network-on-top-of-embeddings">Method #2: Neural Network on top of Embeddings</h3>
<p>Since we&rsquo;re already dealing with AI models and already have PyTorch installed to generate the embeddings, we might as well try the traditional approach of training a <a href="https://en.wikipedia.org/wiki/Multilayer_perceptron">multilayer perceptron</a> (MLP) neural network on top of the embeddings (<a href="https://github.com/minimaxir/imdb-embeddings/blob/main/pytorch_model_train_mlp.ipynb">Jupyter Notebook</a>). This workflow sounds much more complicated than fitting a traditional model as above, but PyTorch makes MLP construction straightforward, and Hugging Face&rsquo;s <a href="https://huggingface.co/docs/transformers/en/main_classes/trainer">Trainer class</a> incorporates model training best practices by default, although its <code>compute_loss</code> function has to be tweaked to minimize MSE specifically.</p>
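<p>That <code>compute_loss</code> tweak can be as small as a subclass override; a sketch, assuming the model&rsquo;s forward pass returns raw rating predictions and each batch carries a <code>labels</code> tensor (exact <code>Trainer</code> method signatures vary slightly across transformers versions):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3">import torch.nn.functional as F
from transformers import Trainer

class MSETrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        preds = model(**inputs).squeeze(-1)  # raw rating predictions
        loss = F.mse_loss(preds, labels)
        return (loss, preds) if return_outputs else loss
</code></pre></div>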
<p>The PyTorch model, using a loop to set up the MLP blocks, looks something like this:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">RatingsModel</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">linear_dims</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="mi">6</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="n">dims</span> <span class="o">=</span> <span class="p">[</span><span class="mi">768</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="n">linear_dims</span><span class="p">]</span> <span class="o">*</span> <span class="n">num_layers</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">mlp</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">([</span>
</span></span><span class="line"><span class="cl">            <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">                <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dims</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">dims</span><span class="p">[</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">]),</span>
</span></span><span class="line"><span class="cl">                <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">(),</span>
</span></span><span class="line"><span class="cl">                <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm1d</span><span class="p">(</span><span class="n">dims</span><span class="p">[</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">]),</span>
</span></span><span class="line"><span class="cl">                <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.6</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            <span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="p">])</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dims</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">targets</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">mlp</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>  <span class="c1"># return 1D output if batched inputs</span>
</span></span></code></pre></div><p>This MLP is 529k parameters total: large for an MLP, but given the 222k row input dataset, it&rsquo;s not egregiously large.</p>
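<p>The <code>compute_loss</code> tweak mentioned earlier amounts to a small Trainer subclass. A minimal sketch, assuming the dataset supplies a <code>targets</code> key alongside the embedding inputs (the extra <code>**kwargs</code> absorbs signature differences between transformers versions):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3">import torch.nn.functional as F
from transformers import Trainer

class RegressionTrainer(Trainer):
    # Replace the Trainer's default loss with MSE against the rating targets.
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        targets = inputs.pop("targets")
        preds = model(**inputs)
        loss = F.mse_loss(preds, targets)
        return (loss, preds) if return_outputs else loss
</code></pre></div>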
<p>The real difficulty with this MLP approach is that it&rsquo;s <em>too effective</em>: even with less than 1 million parameters, the model overfits severely, quickly converging to a 0.00 train MSE while the test set MSE explodes. That&rsquo;s why <code>Dropout</code> is set to the atypically high probability of <code>0.6</code>.</p>
<p>Fortunately, MLPs are fast to train: training for 600 epochs (total passes through the full training dataset) took about 17 minutes on the GPU. Here are the training results:</p>
<figure>

    <img loading="lazy" srcset="/2025/06/movie-embeddings/training_mlp_hu_db4d2b769213c385.webp 320w,/2025/06/movie-embeddings/training_mlp_hu_99fc40ac0f82af11.webp 768w,/2025/06/movie-embeddings/training_mlp_hu_c64c2a10817470c0.webp 1024w,/2025/06/movie-embeddings/training_mlp.png 1200w" src="training_mlp.png"/> 
</figure>

<p>The lowest logged test MSE was <strong>1.074</strong>: a slight improvement over the Support Vector Machine approach.</p>
<h3 id="method-3-just-train-a-llm-from-scratch-dammit">Method #3: Just Train a LLM From Scratch Dammit</h3>
<p>There is a possibility that a pretrained embedding model trained on the entire internet could intrinsically contain relevant signal about popular movies—such as a movie winning awards, which would imply a high IMDb rating—and that knowledge could leak into the test set and provide misleading results. This may not be a significant issue in practice, since movie trivia is such a small part of the training data and <code>gte-modernbert-base</code> is too small to memorize exact information.</p>
<p>For the sake of comparison, let&rsquo;s try training an LLM from scratch on top of the raw movie JSON representations to see if we can get better results without the possibility of leakage (<a href="https://github.com/minimaxir/imdb-embeddings/blob/main/pytorch_model_train_llm.ipynb">Jupyter Notebook</a>). I had specifically avoided this approach because the compute required to train an LLM is much, much higher than for an SVM or MLP model, and leveraging a pretrained model generally gives better results. In this case, since we don&rsquo;t need an LLM with all the knowledge of human existence, we can train a much smaller model that <em>only</em> knows how to work with the movie JSON representations and can figure out relationships between actors, and whether titles are sequels, by itself. Hugging Face transformers makes this workflow surprisingly straightforward by not only having functionality to train your own custom tokenizer (in this case, shrinking the vocabulary from 50k to 5k tokens) so the data is encoded more efficiently, but also allowing the construction of a ModernBERT model with any number of layers and units. I opted for a 5M parameter LLM (SLM?), albeit with less dropout, since high dropout causes learning issues for LLMs specifically.</p>
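<p>A minimal sketch of that setup, assuming <code>json_docs</code> is an iterable of the movie JSON strings; the layer and width values here are illustrative, not the exact 5M parameter configuration:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3">from transformers import AutoTokenizer, ModernBertConfig, ModernBertModel

# Distill a small, domain-specific tokenizer from an existing one
base_tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
tokenizer = base_tokenizer.train_new_from_iterator(json_docs, vocab_size=5000)

# Construct a tiny ModernBERT from scratch
config = ModernBertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=4,
    num_attention_heads=4,
)
model = ModernBertModel(config)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
</code></pre></div>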
<p>The actual PyTorch model wrapper is, surprisingly, even more concise than the MLP approach:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">RatingsModel</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">transformer_model</span> <span class="o">=</span> <span class="n">model</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">,</span> <span class="n">targets</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">transformer_model</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">            <span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">output_hidden_states</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="p">)</span>
</span></span><span class="line"><span class="cl">        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">last_hidden_state</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span>  <span class="c1"># the &#34;[CLS] vector&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">output</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>  <span class="c1"># return 1D output if batched inputs</span>
</span></span></code></pre></div><p>Essentially, the model trains its own &ldquo;text embedding,&rdquo; although instead of an embedding optimized for textual similarity, it&rsquo;s a representation that can easily be translated into a numeric rating.</p>
<p>Because the computation needed for training an LLM from scratch is much higher, I only trained the model for 10 epochs, which still took twice as long as the 600 epochs of the MLP approach. Given that, the results are surprising:</p>
<figure>

    <img loading="lazy" srcset="/2025/06/movie-embeddings/training_llm_hu_2355de410bfc61c1.webp 320w,/2025/06/movie-embeddings/training_llm_hu_cfcd114ac3c12003.webp 768w,/2025/06/movie-embeddings/training_llm_hu_f6c75fc2deeead45.webp 1024w,/2025/06/movie-embeddings/training_llm.png 1200w" src="training_llm.png"/> 
</figure>

<p>The LLM approach did much better than my previous attempts, hitting a new lowest test MSE of <strong>1.026</strong> after only 4 passes through the data! After that, it definitively overfit. I tried other, smaller configurations for the LLM to avoid the overfitting, but none of them ever hit a test MSE that low.</p>
<h2 id="conclusion">Conclusion</h2>
<p>Let&rsquo;s look at the model comparison again, this time adding the results from training an MLP and training an LLM from scratch:</p>
<figure>

    <img loading="lazy" srcset="/2025/06/movie-embeddings/model_comparison_all_hu_2309fb0cea20f0c.webp 320w,/2025/06/movie-embeddings/model_comparison_all_hu_34af566430bbc603.webp 768w,/2025/06/movie-embeddings/model_comparison_all_hu_1e1d9cf8cdfde789.webp 1024w,/2025/06/movie-embeddings/model_comparison_all.png 1200w" src="model_comparison_all.png"/> 
</figure>

<p>Coming into this post, I genuinely thought that training the MLP on top of embeddings would be the winner given the base embedding model&rsquo;s knowledge of everything, but maybe there&rsquo;s something to just YOLOing and feeding raw JSON input data to a completely new LLM. More research and development is needed.</p>
<p>The differences in model performance across these approaches aren&rsquo;t dramatic, but the iteration itself is interesting, and a near-perfect model was always a long shot given the scarce amount of metadata. The fact that building a model off of text embeddings alone didn&rsquo;t result in a perfect model doesn&rsquo;t mean the approach was a waste of time. The embedding and modeling pipelines I constructed while trying to solve this problem have already paid significant dividends on easier problems, such as identifying the efficiency of <a href="https://minimaxir.com/2025/02/embeddings-parquet/">storing embeddings in Parquet and manipulating them with Polars</a>.</p>
<p>It&rsquo;s impossible (and pointless) to pinpoint the exact reason the original Reddit poster got rejected: it could have been the neural network approach, or even something out of their control, such as the company halting hiring and being too disorganized to tell the candidate. To be clear, if I were applying for a data science role myself, I wouldn&rsquo;t use the techniques in this blog post (that UMAP data visualization would get me instantly rejected!) and would instead do more traditional EDA and non-neural-network modeling to showcase my data science knowledge to the hiring manager. But for my professional work, I will definitely try starting any modeling exploration with an embeddings-based approach wherever possible: at the absolute worst, it&rsquo;s a very strong baseline that will be hard to beat.</p>
<p><em>All of the Jupyter Notebooks and data visualization code for this blog post are available open-source in <a href="https://github.com/minimaxir/imdb-embeddings/">this GitHub repository</a>.</em></p>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p>I am not a fan of using GBT variable importance as a decision-making metric: variable importance does not tell you magnitude or <em>direction</em> of the feature in the real world, but it does help identify which features can be pruned for model development iteration.&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:2">
<p>To get a sense of how old they are, they are only available as <a href="https://en.wikipedia.org/wiki/Tab-separated_values">TSV files</a>, which is a data format so old and prone to errors that many data libraries have dropped explicit support for it. Amazon, please release the datasets as CSV or Parquet files instead!&#160;<a href="#fnref:2" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:3">
<p>Two other useful features of <code>gte-modernbert-base</code>, though not strictly relevant to these movie embeddings, are a) it&rsquo;s a cased model, so it can identify meaning from upper-case text, and b) it does not require a prefix such as <code>search_query</code> and <code>search_document</code> as <a href="https://huggingface.co/nomic-ai/nomic-embed-text-v1.5">nomic-embed-text-v1.5 does</a> to guide its results, which is an annoying requirement for those models.&#160;<a href="#fnref:3" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:4">
<p>The trick here is the <code>detach()</code> function for the computed embeddings, otherwise the GPU doesn&rsquo;t free up the memory once moved back to the CPU. I may or may not have discovered that the hard way.&#160;<a href="#fnref:4" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:5">
<p>As noted earlier, minimizing MSE isn&rsquo;t a competition, but the comparison on roughly the same dataset is good for a sanity check.&#160;<a href="#fnref:5" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>The Best Way to Use Text Embeddings Portably is With Parquet and Polars</title>
      <link>https://minimaxir.com/2025/02/embeddings-parquet/</link>
      <pubDate>Mon, 24 Feb 2025 10:15:00 -0800</pubDate>
      <guid>https://minimaxir.com/2025/02/embeddings-parquet/</guid>
      <description>Never store embeddings in a CSV!</description>
      <content:encoded><![CDATA[<p><a href="https://stackoverflow.blog/2023/11/09/an-intuitive-introduction-to-text-embeddings/">Text embeddings</a>, particularly modern embeddings generated from large language models, are one of the most useful applications coming from the generative AI boom. Embeddings are a list of numbers which represent an object: in the case of text embeddings, they can represent words, sentences, and full paragraphs and documents, and they do so with a surprising amount of distinctiveness.</p>
<p>Recently, I created text embeddings representing every distinct <a href="https://magic.wizards.com/en">Magic: the Gathering</a> card released as of the February 2025 Aetherdrift expansion: 32,254 in total. With these embeddings, I can find the mathematical similarity between cards through the encoded representation of their card design, including all mechanical attributes such as the card name, card cost, card text, and even card rarity.</p>
<figure>

    <img loading="lazy" srcset="/2025/02/embeddings-parquet/wog_hu_7ed6be2e5737eeb4.webp 320w,/2025/02/embeddings-parquet/wog_hu_81c75e037d833a96.webp 768w,/2025/02/embeddings-parquet/wog.webp 976w" src="wog.webp"
         alt="The iconic Magic card Wrath of God, along with its top four most similar cards identified using their respective embeddings. The similar cards are valid matches, with similar card text and card types."/> <figcaption>
            <p>The iconic Magic card <a href="https://gatherer.wizards.com/pages/card/Details.aspx?multiverseid=129808">Wrath of God</a>, along with its top four most similar cards identified using their respective embeddings. The similar cards are valid matches, with similar card text and card types.</p>
        </figcaption>
</figure>

<p>Additionally, I can create a fun 2D <a href="https://umap-learn.readthedocs.io/en/latest/">UMAP</a> projection of all those cards, which also identifies interesting patterns:</p>
<figure>

    <img loading="lazy" srcset="/2025/02/embeddings-parquet/mtg_umap_hu_df72981641ef0ffd.webp 320w,/2025/02/embeddings-parquet/mtg_umap_hu_ad2e63ba61f377cd.webp 768w,/2025/02/embeddings-parquet/mtg_umap_hu_7de8f113f1eb20fa.webp 1024w,/2025/02/embeddings-parquet/mtg_umap.webp 1200w" src="mtg_umap.webp"
         alt="The UMAP dimensionality reduction process also implicitly clusters the Magic cards to logical clusters, such as by card color(s) and card type."/> <figcaption>
            <p>The UMAP dimensionality reduction process also implicitly clusters the Magic cards to logical clusters, such as by card color(s) and card type.</p>
        </figcaption>
</figure>

<p>I generated these Magic card embeddings for <em>something special</em> besides a pretty data visualization, but if you are curious how I generated them, they were made using the new-but-underrated <a href="https://huggingface.co/Alibaba-NLP/gte-modernbert-base">gte-modernbert-base</a> embedding model and the process is detailed <a href="https://github.com/minimaxir/mtg-embeddings">in this GitHub repository</a>. The embeddings themselves (including the coordinate values to reproduce the 2D UMAP visualization) are available as a <a href="https://huggingface.co/datasets/minimaxir/mtg-embeddings">Hugging Face dataset</a>.</p>
<p>Most tutorials involving embedding generation omit the obvious question: what do you <em>do</em> with the text embeddings after you generate them? The common solution is to use a <a href="https://en.wikipedia.org/wiki/Vector_database">vector database</a>, such as <a href="https://github.com/facebookresearch/faiss">faiss</a> or <a href="https://qdrant.tech">qdrant</a>, or even a cloud-hosted service such as <a href="https://www.pinecone.io">Pinecone</a>. But those aren&rsquo;t easy to use: faiss has <a href="https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index">confusing configuration options</a>, qdrant requires <a href="https://github.com/qdrant/qdrant?tab=readme-ov-file#client-server">using a Docker container</a> to host the storage server, and Pinecone can get <a href="https://www.pinecone.io/pricing/">very expensive</a> very quickly, and its free Starter tier is limited.</p>
<p>What many don&rsquo;t know about text embeddings is that you don&rsquo;t <em>need</em> a vector database to calculate nearest-neighbor similarity if your data isn&rsquo;t too large. Using <a href="https://numpy.org/doc/stable/index.html">numpy</a> and my Magic card embeddings, a 2D matrix of 32,254 <code>float32</code> embeddings at a dimensionality of 768D (common for &ldquo;smaller&rdquo; LLM embedding models) occupies <strong>94.49 MB</strong> of system memory, which is relatively low for modern personal computers and can fit within free usage tiers of cloud VMs. If both the query vector and the embeddings themselves are unit normalized (many embedding generators normalize by default), then the matrix dot product between the query and embeddings results in a cosine similarity between <code>[-1, 1]</code>, where the higher score is better/more similar. Since dot products are such a fundamental aspect of linear algebra, numpy&rsquo;s implementation is extremely fast: with the help of additional numpy <a href="https://numpy.org/doc/stable/reference/generated/numpy.argpartition.html">sorting</a> <a href="https://numpy.org/doc/2.1/reference/generated/numpy.argsort.html">shenanigans</a>, on my M3 Pro MacBook Pro it takes just <strong>1.08 ms</strong> on average to calculate all 32,254 dot products, find the top 3 most similar embeddings, and return their corresponding matrix <code>idx</code> and cosine similarity <code>score</code>.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">fast_dot_product</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">matrix</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="mi">3</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">dot_products</span> <span class="o">=</span> <span class="n">query</span> <span class="o">@</span> <span class="n">matrix</span><span class="o">.</span><span class="n">T</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="n">idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argpartition</span><span class="p">(</span><span class="n">dot_products</span><span class="p">,</span> <span class="o">-</span><span class="n">k</span><span class="p">)[</span><span class="o">-</span><span class="n">k</span><span class="p">:]</span>
</span></span><span class="line"><span class="cl">    <span class="n">idx</span> <span class="o">=</span> <span class="n">idx</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">dot_products</span><span class="p">[</span><span class="n">idx</span><span class="p">])[::</span><span class="o">-</span><span class="mi">1</span><span class="p">]]</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="n">score</span> <span class="o">=</span> <span class="n">dot_products</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">idx</span><span class="p">,</span> <span class="n">score</span>
</span></span></code></pre></div><p>In most implementations of vector databases, once you insert the embeddings, they&rsquo;re stuck there in a proprietary serialization format and you are locked into that library and service. If you&rsquo;re just building a personal pet project or sanity-checking embeddings to make sure the results are good, that&rsquo;s a huge amount of friction. For example, when I want to experiment with embeddings, I generate them on a cloud server with a GPU since LLM-based embeddings models are often slow to generate without one, and then download them locally to my personal computer. What is the best way to handle embeddings portably such that they can easily be moved between machines and also in a non-proprietary format?</p>
<p>The answer, after much personal trial and error, is Parquet files, although even they involve a surprising amount of nuance. But before we talk about why Parquet files are good, let&rsquo;s talk about how <em>not</em> to store embeddings.</p>
<h2 id="the-worst-ways-to-store-embeddings">The Worst Ways to Store Embeddings</h2>
<p>The incorrect-but-unfortunately-common way to store embeddings is in a text format such as a CSV file. Text data is substantially larger than <code>float32</code> data: for example, a decimal number with full precision (e.g. <code>2.145829051733016968e-02</code>) as a <code>float32</code> is 32 bits/4 bytes, while as a text representation (in this case 24 ASCII <code>char</code>s) it&rsquo;s 24 bytes, <strong>6x larger</strong>. When the CSV is saved and loaded, the data has to be serialized between a numpy and a string representation of the array, which adds significant overhead. Despite that, in <a href="https://github.com/openai/openai-cookbook/blob/a3e98ea4dcf866b5e7a3cb7d63dccaa68c7d63aa/examples/Embedding_Wikipedia_articles_for_search.ipynb">one of OpenAI&rsquo;s official tutorials</a> for their embeddings models, they save the embeddings as a CSV using <a href="https://pandas.pydata.org">pandas</a> with the admitted caveat of &ldquo;Because this example only uses a few thousand strings, we&rsquo;ll store them in a CSV file. (For larger datasets, use a vector database, which will be more performant.)&rdquo;. In the case of the Magic card embeddings, pandas-to-CSV performs the <em>worst</em> of any encoding option: more on why later.</p>
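<p>The size gap is easy to verify directly:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3">import numpy as np

value = np.float32(2.145829051733016968e-02)
print(value.nbytes)                     # 4 bytes as binary float32
print(len("2.145829051733016968e-02"))  # 24 bytes as ASCII text
</code></pre></div>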
<p>Numpy has straightforward native methods to <a href="https://numpy.org/doc/stable/reference/generated/numpy.savetxt.html">save</a> and <a href="https://numpy.org/doc/stable/reference/generated/numpy.loadtxt.html">load</a> embeddings as a <code>.txt</code> file:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3"><span class="line"><span class="cl"><span class="n">np</span><span class="o">.</span><span class="n">savetxt</span><span class="p">(</span><span class="s2">&#34;embeddings_txt.txt&#34;</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">embeddings_r</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">loadtxt</span><span class="p">(</span><span class="s2">&#34;embeddings_txt.txt&#34;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">delimiter</span><span class="o">=</span><span class="s2">&#34; &#34;</span><span class="p">)</span>
</span></span></code></pre></div><p>The resulting file not only takes a few seconds to save and load, but it&rsquo;s also massive: <strong>631.5 MB</strong>!</p>
<p>As an aside, HTTP APIs such as OpenAI&rsquo;s <a href="https://platform.openai.com/docs/guides/embeddings">Embeddings API</a> do transmit the embeddings as text, which adds needless latency and bandwidth overhead. I wish more embedding providers offered <a href="https://grpc.io">gRPC</a> APIs, which would allow transferring binary <code>float32</code> data instead for a performance increase: Pinecone&rsquo;s <a href="https://docs.pinecone.io/reference/python-sdk">Python SDK</a>, for example, does just that.</p>
<p>The second incorrect method to save a matrix of embeddings to disk is to save it as a Python <a href="https://docs.python.org/3/library/pickle.html">pickle</a> object, which stores the object&rsquo;s in-memory representation on disk with a few lines of code using the native <code>pickle</code> library. Pickling is unfortunately common in the machine learning industry since many ML frameworks such as <a href="https://scikit-learn.org/stable/">scikit-learn</a> don&rsquo;t have easy ways to serialize encoders and models. But it comes with two major caveats: pickled files are a massive security risk as they can execute arbitrary code, and they aren&rsquo;t guaranteed to open on other machines or Python versions. It&rsquo;s 2025, just stop pickling if you can.</p>
<p>In the case of the Magic card embeddings, it does indeed work with instant save/loads, and the file size on disk is <strong>94.49 MB</strong>: the same as its memory consumption and about 1/6th of the text size as expected:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3"><span class="line"><span class="cl"><span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="s2">&#34;embeddings_matrix.pkl&#34;</span><span class="p">,</span> <span class="s2">&#34;wb&#34;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="n">pickle</span><span class="o">.</span><span class="n">dump</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">f</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="s2">&#34;embeddings_matrix.pkl&#34;</span><span class="p">,</span> <span class="s2">&#34;rb&#34;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">    <span class="n">embeddings_r</span> <span class="o">=</span> <span class="n">pickle</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
</span></span></code></pre></div><p>But there are still better and easier approaches.</p>
<h2 id="the-intended-but-not-great-way-to-store-embeddings">The Intended-But-Not-Great Way to Store Embeddings</h2>
<p>Numpy itself has a canonical way to <a href="https://numpy.org/doc/2.1/reference/generated/numpy.save.html">save</a> and <a href="https://numpy.org/doc/2.1/reference/generated/numpy.load.html">load</a> matrices — which annoyingly saves as a pickle by default for compatibility reasons, but that can fortunately be disabled by setting <code>allow_pickle=False</code>:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3"><span class="line"><span class="cl"><span class="n">np</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">&#34;embeddings_matrix.npy&#34;</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">allow_pickle</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">embeddings_r</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">&#34;embeddings_matrix.npy&#34;</span><span class="p">,</span> <span class="n">allow_pickle</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
</span></span></code></pre></div><p>File size and I/O speed are the same as with the <code>pickle</code> approach.</p>
<p>This works — and it&rsquo;s something I had used for a while — but in the process it exposes another problem: how do we map metadata (the Magic cards in this case) to embeddings? Currently, we use the <code>idx</code> of the most-similar matches to perform an efficient batched lookup to the source data. In this case, the number of rows matches the number of cards exactly, but what happens if the embeddings matrix needs to be changed, such as to add or remove cards and their embeddings? What happens if you want to add a dataset filter? It becomes a mess that inevitably causes technical debt.</p>
<p>The solution to this is to colocate metadata such as card names, card text, and attributes with their embeddings: that way, if they are later added, removed, or sorted, the results will remain the same. Modern vector databases such as qdrant and Pinecone do just that, with the ability to filter and sort on the metadata at the same time you query the most similar vectors. Colocating metadata in numpy itself is a bad idea, however, as numpy is optimized for numbers and not other data types such as strings, which have <a href="https://numpy.org/devdocs/user/basics.strings.html">limited operations available</a>.</p>
<p>The solution is to look at another file format that can store metadata and embeddings simultaneously, and the answer to that is Parquet files. But there&rsquo;s a rabbit hole as to what&rsquo;s the <em>best</em> way to interact with them.</p>
<h2 id="what-are-parquet-files">What are Parquet files?</h2>
<p>Parquet, developed by the open-source <a href="https://parquet.apache.org">Apache Parquet</a> project, is a file format for handling columnar data, but despite being <a href="https://blog.x.com/engineering/en_us/a/2013/announcing-parquet-10-columnar-storage-for-hadoop">first released in 2013</a> it didn&rsquo;t take off in the data science community until very recently. <sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup> The most relevant feature of Parquet is that the resulting files are typed for each column, and that this typing includes nested lists, such as an embedding which is just a list of <code>float32</code> values. As a bonus, the columnar format allows downstream libraries to save/load them selectively and very quickly, far faster than CSVs and with rare parsing errors. The file format also allows for efficient compression and decompression, but that&rsquo;s less effective with embeddings as there&rsquo;s little redundant data.</p>
<p>For Parquet file I/O, the standard approach is to use the <a href="https://arrow.apache.org">Apache Arrow</a> protocol that is columnar in-memory, which complements the Parquet storage medium on disk. But how do you use Arrow?</p>
<h2 id="how-do-you-use-parquet-files-in-python-for-embeddings">How do you use Parquet files in Python for embeddings?</h2>
<p>Ideally, we need a library that can handle nested data easily, interoperate with numpy for serializing to a matrix, and run fast dot products.</p>
<p>The official Arrow library that <a href="https://arrow.apache.org/docs/python/index.html">interacts with Parquet natively</a> in Python is <a href="https://arrow.apache.org/docs/python/index.html">pyarrow</a>. Here, I have an example Parquet file generated with [SPOILERS] that contains both the card metadata and an <code>embedding</code> column, with the embedding for each row corresponding to that card.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3"><span class="line"><span class="cl"><span class="n">df</span> <span class="o">=</span> <span class="n">pa</span><span class="o">.</span><span class="n">parquet</span><span class="o">.</span><span class="n">read_table</span><span class="p">(</span><span class="s2">&#34;mtg-embeddings.parquet&#34;</span><span class="p">)</span>
</span></span></code></pre></div><figure>

    <img loading="lazy" srcset="/2025/02/embeddings-parquet/parquet_hu_268909d3d8256458.webp 320w,/2025/02/embeddings-parquet/parquet_hu_be20ddd4d423844c.webp 768w,/2025/02/embeddings-parquet/parquet_hu_dc1002cb8e03a874.webp 1024w,/2025/02/embeddings-parquet/parquet.png 1352w" src="parquet.png"
         alt="Pyarrow&rsquo;s table schema from the input Parquet file of Magic card embeddings. Note the embedding column at the bottom is a list of 768 floats."/> <figcaption>
            <p>Pyarrow&rsquo;s table schema from the input Parquet file of Magic card embeddings. Note the <code>embedding</code> column at the bottom is a list of 768 floats.</p>
        </figcaption>
</figure>

<p>But pyarrow is not a DataFrame library, and despite the data being in a Table, it&rsquo;s hard to slice and access: the documentation suggests that you export to pandas if you need more advanced manipulation.</p>
<p>Other, more traditional data science libraries can leverage pyarrow directly. The most popular one is, of course, pandas itself, which does just that to <a href="https://pandas.pydata.org/docs/reference/api/pandas.read_parquet.html">read/write Parquet</a>. There are many, many resources for using pandas well, so it&rsquo;s often the first choice among data science practitioners.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3"><span class="line"><span class="cl"><span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_parquet</span><span class="p">(</span><span class="s2">&#34;mtg-embeddings.parquet&#34;</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s2">&#34;name&#34;</span><span class="p">,</span> <span class="s2">&#34;embedding&#34;</span><span class="p">])</span>
</span></span><span class="line"><span class="cl"><span class="n">df</span>
</span></span></code></pre></div><figure>

    <img loading="lazy" srcset="/2025/02/embeddings-parquet/pandas_embed_hu_43da08f8256fb434.webp 320w,/2025/02/embeddings-parquet/pandas_embed_hu_ffb22e6af150d0a8.webp 768w,/2025/02/embeddings-parquet/pandas_embed_hu_f0379dc63b1b8457.webp 1024w,/2025/02/embeddings-parquet/pandas_embed.png 1224w" src="pandas_embed.png"
         alt="Pandas HTML table output of the Magic card DataFrame when printed in a Jupyter Notebook."/> <figcaption>
            <p>Pandas HTML table output of the Magic card DataFrame when printed in a Jupyter Notebook.</p>
        </figcaption>
</figure>

<p>There&rsquo;s one major weakness for the use case of embeddings: pandas is very bad at nested data. From the image above you&rsquo;ll see that the <code>embedding</code> column <em>appears</em> to be a list of numbers, but it&rsquo;s actually a list of numpy <code>object</code>s, which is a very inefficient datatype, and why I suspect writing it to a CSV is very slow. Simply converting it to numpy with <code>df[&quot;embedding&quot;].to_numpy()</code> results in a 1D array, which is definitely wrong, and trying to cast it to <code>float32</code> doesn&rsquo;t work. I found that the best way to extract the embeddings matrix from a pandas <code>embedding</code> column is to <a href="https://numpy.org/doc/2.1/reference/generated/numpy.vstack.html">np.vstack()</a> the embeddings, e.g. <code>np.vstack(df[&quot;embedding&quot;].to_numpy())</code>, which does result in a <code>(32254, 768)</code> <code>float32</code> matrix as expected. That adds a lot of compute and memory overhead in addition to unnecessary numpy array copies. Finally, after computing the dot products between a candidate query and the embedding matrix, row metadata with the most similar values can then be retrieved using <code>df.loc[idx]</code>. <sup id="fnref:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup></p>
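<p>For reference, a sketch of that workaround:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3">import numpy as np
import pandas as pd

df = pd.read_parquet("mtg-embeddings.parquet", columns=["name", "embedding"])

# df["embedding"].to_numpy() yields a 1D array of object-dtype arrays,
# so the rows must be stacked into a proper (32254, 768) float32 matrix.
embeddings = np.vstack(df["embedding"].to_numpy())
</code></pre></div>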
<p>However, there is another, more recent tabular data library that is not only faster than pandas but also has proper support for nested data. That library is polars.</p>
<h2 id="the-power-of-polars">The Power of polars</h2>
<p><a href="https://pola.rs">Polars</a> is a relatively new Python library which is primarily written in <a href="https://www.rust-lang.org">Rust</a> and <a href="https://docs.pola.rs/#key-features">supports Arrow</a>, which gives it a <a href="https://duckdblabs.github.io/db-benchmark/">massive performance increase</a> over pandas and many other DataFrame libraries. In the case of Magic cards, 32k rows isn&rsquo;t nearly &ldquo;big data&rdquo; and the gains of using a high-performance library are lesser, but there are some unexpected features that coincidentally work <em>perfectly</em> for the embeddings use case.</p>
<p>As with pandas, you read a Parquet file with <code>read_parquet()</code>:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3"><span class="line"><span class="cl"><span class="n">df</span> <span class="o">=</span> <span class="n">pl</span><span class="o">.</span><span class="n">read_parquet</span><span class="p">(</span><span class="s2">&#34;mtg-embeddings.parquet&#34;</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s2">&#34;name&#34;</span><span class="p">,</span> <span class="s2">&#34;embedding&#34;</span><span class="p">])</span>
</span></span><span class="line"><span class="cl"><span class="n">df</span>
</span></span></code></pre></div><figure>

    <img loading="lazy" srcset="/2025/02/embeddings-parquet/polars_embed_hu_98a1dcff6631f16f.webp 320w,/2025/02/embeddings-parquet/polars_embed_hu_7795d47fe1f2255a.webp 768w,/2025/02/embeddings-parquet/polars_embed.png 957w" src="polars_embed.png"
         alt="Polars HTML table output of the Magic card DataFrame when printed in a Jupyter Notebook."/> <figcaption>
            <p>Polars HTML table output of the Magic card DataFrame when printed in a Jupyter Notebook.</p>
        </figcaption>
</figure>

<p>There&rsquo;s a notable difference in the table output compared to <code>pandas</code>: it also reports the data type of its columns, and more importantly, it shows that the <code>embedding</code> column consists of arrays, all <code>float32</code>s, and all length 768. That&rsquo;s a great start!</p>
<p>Polars also has a <code>to_numpy()</code> function. Unlike pandas, if you call <code>to_numpy()</code> on a column as a Series, e.g. <code>df['embedding'].to_numpy()</code>, the returned object is a numpy 2D matrix: no <code>np.vstack()</code> needed. If you look at the <a href="https://docs.pola.rs/api/python/stable/reference/series/api/polars.Series.to_numpy.html">documentation</a> for the function, there&rsquo;s a curious feature:</p>
<blockquote>
<p>This operation copies data only when necessary. The conversion is zero copy when all of the following hold: [&hellip;]</p>
</blockquote>
<p>Zero copy! And in the case of columnar-stored embeddings, the conditions will always hold, but you can set <code>allow_copy=False</code> to throw an error just in case.</p>
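<p>In code, extracting the full embeddings matrix is a one-liner:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3">embeddings = df["embedding"].to_numpy(allow_copy=False)  # raises instead of silently copying
print(embeddings.shape, embeddings.dtype)  # (32254, 768) float32
</code></pre></div>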
<p>Inversely, if you want to add a 2D embeddings matrix to an existing DataFrame and colocate each embedding&rsquo;s corresponding metadata, such as after you batch-generate thousands of embeddings and want to save and download the resulting Parquet, it&rsquo;s just as easy as adding a column to the DataFrame.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3"><span class="line"><span class="cl"><span class="n">df</span> <span class="o">=</span> <span class="n">pl</span><span class="o">.</span><span class="n">with_columns</span><span class="p">(</span><span class="n">embedding</span><span class="o">=</span><span class="n">embeddings</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">df</span><span class="o">.</span><span class="n">write_parquet</span><span class="p">(</span><span class="s2">&#34;mtg-embeddings.parquet&#34;</span><span class="p">)</span>
</span></span></code></pre></div><p>Now, let&rsquo;s put the speed to the test using all the Magic card metadata. What if we perform embedding similarity on a Magic card, but beforehand dynamically filter the dataset according to user parameters (therefore filtering the candidate embeddings at the same time since they are colocated) and perform the similarity calculations quickly as usual? Let&rsquo;s try with <a href="https://gatherer.wizards.com/pages/card/details.aspx?multiverseid=87908">Lightning Helix</a>, a card whose effects are self-explanatory even to those who don&rsquo;t play Magic.</p>
<figure>

    <img loading="lazy" srcset="/2025/02/embeddings-parquet/helix_1_hu_9f15db636cb74690.webp 320w,/2025/02/embeddings-parquet/helix_1_hu_c58b97e1d1c6f502.webp 768w,/2025/02/embeddings-parquet/helix_1.webp 976w" src="helix_1.webp"
         alt="The most similar cards to Lightning Helix do have similar effects, although &ldquo;Lightning&rdquo; cards dealing damage is a common trope in Magic. Warleader&rsquo;s Helix is a direct reference to Lightning Helix."/> <figcaption>
            <p>The most similar cards to Lightning Helix do have similar effects, although &ldquo;Lightning&rdquo; cards dealing damage is a common trope in Magic. <a href="https://gatherer.wizards.com/pages/card/Details.aspx?multiverseid=456806">Warleader&rsquo;s Helix</a> is a direct reference to Lightning Helix.</p>
        </figcaption>
</figure>

<p>Now we can also find similar cards to Lightning Helix but with filters. In this case, let&rsquo;s look for a Sorcery (Sorceries are analogous to Instants but tend to be stronger since they have play limitations) that has Black as one of its colors. This limits the candidates to ~3% of the original dataset. The resulting code would look like this, given a <code>query_embed</code>:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3"><span class="line"><span class="cl"><span class="n">df_filter</span> <span class="o">=</span> <span class="n">df</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="n">pl</span><span class="o">.</span><span class="n">col</span><span class="p">(</span><span class="s2">&#34;type&#34;</span><span class="p">)</span><span class="o">.</span><span class="n">str</span><span class="o">.</span><span class="n">contains</span><span class="p">(</span><span class="s2">&#34;Sorcery&#34;</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">    <span class="n">pl</span><span class="o">.</span><span class="n">col</span><span class="p">(</span><span class="s2">&#34;manaCost&#34;</span><span class="p">)</span><span class="o">.</span><span class="n">str</span><span class="o">.</span><span class="n">contains</span><span class="p">(</span><span class="s2">&#34;B&#34;</span><span class="p">),</span>
</span></span><span class="line"><span class="cl"><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">embeddings_filter</span> <span class="o">=</span> <span class="n">df_filter</span><span class="p">[</span><span class="s2">&#34;embedding&#34;</span><span class="p">]</span><span class="o">.</span><span class="n">to_numpy</span><span class="p">(</span><span class="n">allow_copy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">idx</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">fast_dot_product</span><span class="p">(</span><span class="n">query_embed</span><span class="p">,</span> <span class="n">embeddings_filter</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">related_cards</span> <span class="o">=</span> <span class="n">df_filter</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
</span></span></code></pre></div><p>As an aside, in polars you can call row subsets of a DataFrame with <code>df[idx]</code>, which makes it infinitely better than pandas and its <code>df.iloc[idx]</code>.</p>
<p>The resulting similar cards:</p>
<figure>

    <img loading="lazy" srcset="/2025/02/embeddings-parquet/helix_2_hu_f6db1b1e0be3033.webp 320w,/2025/02/embeddings-parquet/helix_2_hu_1d74aa59da2a8d38.webp 768w,/2025/02/embeddings-parquet/helix_2.webp 976w" src="helix_2.webp"
         alt="In this case, the similarity focuses on card text similarity, and these cards have near identical text. Smiting Helix is also a direct reference to Lightning Helix."/> <figcaption>
            <p>In this case, the similarity focuses on card text similarity, and these cards have near identical text. <a href="https://gatherer.wizards.com/Pages/Card/Details.aspx?multiverseid=464058">Smiting Helix</a> is also a direct reference to Lightning Helix.</p>
        </figcaption>
</figure>

<p>Speed-wise, the code runs at about <strong>1.48 ms</strong> on average, or about 37% slower than calculating all dot products, so the filtering does still have some overhead, which is not surprising since the filtered dataframe does copy the embeddings. Overall, it&rsquo;s still more than fast enough for a hobby project.</p>
<p>I&rsquo;ve created an <a href="https://colab.research.google.com/drive/19C_9sBC0Py2PlXYihl2ed378oGyroONZ?usp=sharing">interactive Colab Notebook</a> where you can generate similarities for any Magic card, and apply any filters you want!</p>
<h2 id="scaling-to-vector-databases">Scaling to Vector Databases</h2>
<p>Again, all of this assumes that you are using the embeddings for smaller/noncommercial projects. If you scale to hundreds of thousands of embeddings, the Parquet and dot product approach for finding similarity should still be fine, but if it&rsquo;s a business-critical application, the marginal costs of querying a vector database are likely lower than the marginal revenue from a snappy similarity lookup. Deciding how to make these tradeoffs is the fun part of MLOps!</p>
<p>In the case that the amount of vectors is too large to fit into memory but you don&rsquo;t want to go all-in on vector databases, another option that may be worth considering is using an old-fashioned database that can now support vector embeddings. Notably, <a href="https://www.sqlite.org">SQLite</a> databases are just a single portable file; however, interacting with them involves more technical overhead and considerations than the <code>read_parquet()</code> and <code>write_parquet()</code> of polars. One notable implementation of vector search in SQLite is the <a href="https://alexgarcia.xyz/sqlite-vec/">sqlite-vec extension</a>, which also allows for simultaneous filtering and similarity calculations.</p>
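<p>For the curious, a rough sketch of what sqlite-vec usage looks like from Python, using the standard library&rsquo;s <code>sqlite3</code>; the table name and the <code>struct</code>-based vector packing here are illustrative, so check the sqlite-vec documentation for the current syntax:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3">import sqlite3
import struct

import sqlite_vec

db = sqlite3.connect("cards.db")
db.enable_load_extension(True)
sqlite_vec.load(db)

# A virtual table storing one 768D float32 vector per row
db.execute("CREATE VIRTUAL TABLE IF NOT EXISTS card_vecs USING vec0(embedding float[768])")
db.execute(
    "INSERT INTO card_vecs(rowid, embedding) VALUES (?, ?)",
    (1, struct.pack("768f", *embedding)),  # embedding: any 768-float sequence
)

# k-nearest-neighbors query against a query vector
rows = db.execute(
    "SELECT rowid, distance FROM card_vecs WHERE embedding MATCH ? ORDER BY distance LIMIT 3",
    (struct.pack("768f", *query_embed),),
).fetchall()
</code></pre></div>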
<p>The next time you&rsquo;re working with embeddings, consider whether you really need a vector database. For many applications, the combination of Parquet files and polars provides everything you need: efficient storage, fast similarity search, and easy metadata filtering. Sometimes the simplest solution is the best one.</p>
<p><em>The code used to process the Magic card data, create the embeddings, and plot the UMAP 2D projection, is all available <a href="https://github.com/minimaxir/mtg-embeddings">in this GitHub repository</a>.</em></p>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p>I suspect the main bottleneck to widespread Parquet support is Microsoft Excel&rsquo;s and other spreadsheet software&rsquo;s lack of native support for the format. Every data scientist will be very, very happy if/when they do!&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:2">
<p>OpenAI&rsquo;s <a href="https://github.com/openai/openai-cookbook/blob/main/examples/Question_answering_using_embeddings.ipynb">approach</a> using pandas to find colocated similarity is to manually iterate through the entire dataframe, calculate each cosine similarity between the candidate and the query for each row, then sort by scores. That implementation definitely does not scale.&#160;<a href="#fnref:2" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>The Super Effectiveness of Pokémon Embeddings Using Only Raw JSON and Images</title>
      <link>https://minimaxir.com/2024/06/pokemon-embeddings/</link>
      <pubDate>Wed, 26 Jun 2024 10:00:00 -0700</pubDate>
      <guid>https://minimaxir.com/2024/06/pokemon-embeddings/</guid>
      <description>Embeddings encourage engineers to go full YOLO because it&amp;rsquo;s actually rewarding to do so!</description>
      <content:encoded><![CDATA[<p><a href="https://en.wikipedia.org/wiki/Word_embedding">Embeddings</a> are one of the most useful but unfortunately underdiscussed concepts in the artificial intelligence space relative to the modern generative AI gigahype. Embeddings are a set of hundreds of numbers which uniquely correspond to a given object that define its dimensionality, nowadays in a multiple of 128 such as 384D, 768D, or even 1536D. <sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup> The larger the embeddings, the more &ldquo;information&rdquo; and distinctiveness each can contain, in theory. These embeddings can be used as-is for traditional <a href="https://en.wikipedia.org/wiki/Regression_analysis">regression</a> and <a href="https://en.wikipedia.org/wiki/Statistical_classification">classification</a> problems with your favorite statistical modeling library, but what&rsquo;s really useful about these embeddings is that if you can find the minimum mathematical distance between a given query embedding and another set of embeddings, you can then find which is the most similar: extremely useful for many real-world use cases such as search.</p>
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/embedding_demo_hu_d3f88226f240b184.webp 320w,/2024/06/pokemon-embeddings/embedding_demo_hu_2c5a706d1ebace78.webp 768w,/2024/06/pokemon-embeddings/embedding_demo_hu_34918a6ff306a4ac.webp 1024w,/2024/06/pokemon-embeddings/embedding_demo.png 1178w" src="embedding_demo.png"
         alt="An example sentence embedding generated using Sentence Transformers: this embedding is 384D."/> <figcaption>
            <p>An example sentence embedding generated using Sentence Transformers: this embedding is 384D.</p>
        </figcaption>
</figure>

<p>Although any kind of object can be represented by an embedding, text is the classical use case for embeddings, popularized with the original <a href="https://en.wikipedia.org/wiki/Word2vec">word2vec</a> <a href="https://arxiv.org/abs/1301.3781">paper</a>, which, along with <a href="https://arxiv.org/abs/1310.4546">later work</a>, showed that word embeddings could be used to calculate relationships such as king - man + woman = queen. You could then, for example, create a sentence embedding by averaging all of its word embeddings. This actually works, although this naive averaging does not take word position and punctuation into account, both of which are critically important in identifying context for a given text.</p>
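<p>With <a href="https://radimrehurek.com/gensim/">gensim</a> and a set of pretrained word vectors, that arithmetic takes just a few lines; the classic Google News vectors file is used here purely as an example:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py3" data-lang="py3">from gensim.models import KeyedVectors

wv = KeyedVectors.load_word2vec_format("GoogleNews-vectors-negative300.bin", binary=True)

# king - man + woman: the top result should be "queen"
print(wv.most_similar(positive=["king", "woman"], negative=["man"], topn=1))
</code></pre></div>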
<p>Deep learning then entered the picture, and it was eventually discovered that large language models like <a href="https://en.wikipedia.org/wiki/BERT_%28language_model%29">BERT</a> can return embeddings as an emergent behavior. Unlike the word averaging above, <a href="https://en.wikipedia.org/wiki/Transformer_%28deep_learning_architecture%29">transformers</a>-based LLMs can account for positional relationships more robustly thanks to their <a href="https://en.wikipedia.org/wiki/Attention_%28machine_learning%29">attention mechanisms</a>, and, due to their more advanced model input <a href="https://www.freecodecamp.org/news/how-tokenizers-shape-ai-understanding/">tokenization</a> strategies that go beyond whole words, can also better incorporate punctuation. One very popular Python library for easily creating embeddings using LLMs is <a href="https://sbert.net">Sentence Transformers</a>, especially with the <a href="https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2">all-MiniLM-L6-v2</a> model (<a href="https://huggingface.co/models?pipeline_tag=sentence-similarity&amp;sort=downloads">30 million downloads monthly</a>!), which balances embedding encoding speed and robustness with its 384D embeddings.</p>
<p>How well can these embedding models work beyond just normal sentences? Can they encode larger bodies of text into a consistent space? The context length of <code>all-MiniLM-L6-v2</code> is 512 tokens, which can only fit a couple paragraphs of text, but newer LLMs have much higher context lengths.</p>
<p>I recalled one of my early projects as an aspiring data scientist: creating <a href="https://github.com/minimaxir/pokemon-3d">Pokémon vectors</a> by <a href="https://github.com/minimaxir/pokemon-3d/blob/master/pokemon_spark_pca.ipynb">manually transforming each Pokémon&rsquo;s metadata</a>, such as its base stats, type(s), moves, abilities, and miscellaneous attributes such as color, shape, and habitat. After that, I was able to cluster them.</p>
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/cluster_hu_7ca1c5b82a702cc7.webp 320w,/2024/06/pokemon-embeddings/cluster.png 676w" src="cluster.png"
         alt="3D projection of my Pokémon vectors back in 2016: the colors are Pokémon types, and the methodology seemed to favor clustering by them."/> <figcaption>
            <p>3D projection of my Pokémon vectors back in 2016: the colors are Pokémon types, and the methodology seemed to favor clustering by them.</p>
        </figcaption>
</figure>

<p>Those familiar with Pokémon know that&rsquo;s just scratching the surface: there&rsquo;s even more metadata, such as rich text data like a Pokémon&rsquo;s Pokédex entries and the exact locations where it can be encountered, both of which say a lot about a given Pokémon. At the time, there was no efficient LLM to encode all of that extra metadata.</p>
<p>Why not try to encode all Pokémon metadata using a text embedding model and see what happens? Will we be able to identify the most &ldquo;similar&rdquo; Pokémon? What is a &ldquo;similar&rdquo; Pokémon anyways? Can we find the <em>weirdest</em> Pokémon by the most dissimilar? Can we encode other Pokémon data such as images? Let&rsquo;s find out!</p>
<h2 id="how-embeddings-are-generated-using-llms">How Embeddings Are Generated Using LLMs</h2>
<p>First, some relevant technical background on how LLMs can be used to create embeddings, since there&rsquo;s a surprising amount of confusion about how they work beyond the SEO-oriented &ldquo;embeddings are for <a href="https://www.cloudflare.com/learning/ai/what-is-vector-database/">vector databases</a>&rdquo;.</p>
<p>Modern embedding models are commonly trained in one of two ways. The first is as emergent behavior from training an LLM normally: because LLMs need to build a latent representation before passing the output to a classification head such as <a href="https://en.wikipedia.org/wiki/Generative_pre-trained_transformer">GPT</a>&rsquo;s next-token prediction, taking the last layer (&ldquo;hidden state&rdquo;) of a model and averaging across the positional axis results in an embedding with the same dimensionality as the hidden state. Since LLMs have to learn how to uniquely represent text in a common latent space, this approach is natural. The second way is to train a model to output the embeddings directly: in this case, the training process typically uses <a href="https://encord.com/blog/guide-to-contrastive-learning/">contrastive learning</a> to minimize the semantic distance between the generated embeddings of a pair of known-similar text documents, and maximize the distance between a dissimilar pair. Both of these techniques can be used together, of course: pretrain an LLM on a large body of text, then finetune it with contrastive learning.</p>
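<p>A minimal sketch of the first, &ldquo;emergent&rdquo; approach using Hugging Face <code>transformers</code> with <code>bert-base-uncased</code>: grab the last hidden state and average it across the token axis.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello, embeddings!", return_tensors="pt")
with torch.no_grad():
    # shape: (batch, sequence_length, hidden_size)
    last_hidden = model(**inputs).last_hidden_state

# average across the positional (sequence) axis: shape (batch, 768)
embedding = last_hidden.mean(dim=1)
</code></pre></div>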
<p>Embedding models get the benefits of all the research invested into improving LLMs for generative AI, such as inference speed and longer context windows. Normally, using those larger context windows requires a quadratic increase in computation (e.g. a 2<em>x</em> increase in input length requires 4<em>x</em> more computation), but thanks to <a href="https://huggingface.co/docs/text-generation-inference/en/conceptual/flash_attention">FlashAttention</a> and <a href="https://blog.eleuther.ai/rotary-embeddings/">rotary positional embeddings,</a> it&rsquo;s now feasible to train models with massively-large context windows without a massive datacenter and then run those models on consumer hardware.</p>
<p>Since 2022, <a href="https://openai.com">OpenAI</a> has offered the text embedding model <a href="https://openai.com/index/new-and-improved-embedding-model/">text-embedding-ada-002</a> behind a paid API, with a class-leading context window of 8,192 tokens: a substantial increase over <code>all-MiniLM-L6-v2</code>&rsquo;s 512 limit, and no other open-source model could compete. That was the case until February 2024, when <a href="https://www.nomic.ai">Nomic AI</a> released <a href="https://huggingface.co/nomic-ai/nomic-embed-text-v1">nomic-embed-text-v1</a>, a fully open-source embedding model with an 8,192-token context window and a permissive <a href="https://www.apache.org/licenses/LICENSE-2.0">Apache license</a>, and quickly followed up with <a href="https://huggingface.co/nomic-ai/nomic-embed-text-v1.5">nomic-embed-text-v1.5</a>. In academic benchmarks, this free model performed even better than OpenAI&rsquo;s paid embedding model <a href="https://blog.nomic.ai/posts/nomic-embed-text-v1">thanks to its training regimen</a>, which uses both of the embedding model training tricks described above. That, along with its long context window, caused it to become another one of the most downloaded open-source embedding models (~10 million downloads per month).</p>
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/nomic_embeddings_demo_hu_4dadd5c725d77e60.webp 320w,/2024/06/pokemon-embeddings/nomic_embeddings_demo_hu_6682e6dc986b6850.webp 768w,/2024/06/pokemon-embeddings/nomic_embeddings_demo_hu_1c6ffa82964e1090.webp 1024w,/2024/06/pokemon-embeddings/nomic_embeddings_demo.png 1340w" src="nomic_embeddings_demo.png"
         alt="A sentence embedding generated using nomic-embed-text-v1.5 adapted from the official example: this is a lower-level interface than Sentence Transformers (Hugging Face transformers and PyTorch) but is more clear as to what is going on. mean_pooling() uses an atypical attention-masked averaging that is theoretically better for small inputs than averaging the entire last hidden state."/> <figcaption>
            <p>A sentence embedding generated using <code>nomic-embed-text-v1.5</code> adapted from the official example: this is a lower-level interface than Sentence Transformers (<a href="https://huggingface.co/docs/transformers/en/index">Hugging Face transformers</a> and <a href="https://pytorch.org">PyTorch</a>) but is more clear as to what is going on. <code>mean_pooling()</code> uses an atypical attention-masked averaging that is theoretically better for small inputs than averaging the entire last hidden state.</p>
        </figcaption>
</figure>

<p>The <code>F.normalize()</code> function is a key step for finding similar embeddings efficiently. <sup id="fnref:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup> A <a href="https://en.wikipedia.org/wiki/Unit_vector">unit-normalized</a> vector has a length (Euclidean norm) of exactly 1. If you perform a matrix multiplication (an extremely fast computational operation) of a normalized vector against a matrix of normalized vectors, the result will be the <a href="https://en.wikipedia.org/wiki/Cosine_similarity">cosine similarity</a>, constrained between the values of 1 for identical matches and -1 for the most dissimilar matches.</p>
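<p>A minimal sketch of that trick with PyTorch, using random stand-in vectors: once everything is unit-normalized, one matrix multiplication yields every cosine similarity at once.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import torch
import torch.nn.functional as F

query = F.normalize(torch.randn(1, 768), p=2, dim=1)      # one query embedding
corpus = F.normalize(torch.randn(1302, 768), p=2, dim=1)  # a matrix of embeddings

# dot products of unit vectors are cosine similarities, in [-1, 1]
cosine_sims = (query @ corpus.T).squeeze()
print(cosine_sims.shape)  # torch.Size([1302])
</code></pre></div>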
<p>Now that we have thoroughly covered how embeddings work, let&rsquo;s see if we can put that 8,192 context window to the test.</p>
<h2 id="what-kind-of-pokémon-embedding-are-you">What Kind of Pokémon Embedding Are You?</h2>
<p>Before encoding Pokémon data, I need to first get Pokémon data, but where? Nintendo certainly won&rsquo;t have an API for Pokémon data, and web scraping a Pokémon wiki such as <a href="https://bulbapedia.bulbagarden.net/wiki/Main_Page">Bulbapedia</a> is both impractical and rude. Fortunately, there&rsquo;s an unofficial Pokémon API known appropriately as <a href="https://pokeapi.co">PokéAPI</a>, which is both open source and has been around for years without Nintendo taking them down. Of note, PokéAPI has a <a href="https://beta.pokeapi.co/graphql/console/">GraphQL interface</a> to its Pokémon data, allowing you to query exactly what you want without having to do relationship mapping or data joins.</p>
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/graphql_hu_545910786cbd7bf6.webp 320w,/2024/06/pokemon-embeddings/graphql_hu_5d76d10c482e4154.webp 768w,/2024/06/pokemon-embeddings/graphql_hu_d03bd1849088a102.webp 1024w,/2024/06/pokemon-embeddings/graphql.png 1260w" src="graphql.png"
         alt="A simple GraphQL query to get all Pokémon IDs and names, sorted by ID."/> <figcaption>
            <p>A simple GraphQL query to get all Pokémon IDs and names, sorted by ID.</p>
        </figcaption>
</figure>

<p>Since we can get Pokémon data in a nicely structured JSON dictionary, why not keep it that way? After writing a <a href="https://github.com/minimaxir/pokemon-embeddings/blob/main/query.gql">massive GraphQL query</a> to specify all mechanically relevant Pokémon data, all it takes is a single GET request to download it all, about 16MB of data total. This includes over 1,000 Pokémon up to the Scarlet/Violet <em>The Hidden Treasure of Area Zero</em> DLC: 1,302 Pokémon total if you include the Special forms of Pokémon (e.g. <a href="https://bulbapedia.bulbagarden.net/wiki/Mega_Evolution">Mega Evolutions</a>), which I&rsquo;m excluding for simplicity.</p>
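<p>A minimal sketch of that download step with <code>requests</code>, assuming PokéAPI&rsquo;s beta GraphQL endpoint URL and a much smaller query than the real one in the repo (shown here as a POST, the standard way to send a GraphQL query):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import json
import requests

# toy query: just IDs and names, sorted by ID (the real query requests far more)
query = """
query {
  pokemon_v2_pokemon(order_by: {id: asc}) {
    id
    name
  }
}
"""

resp = requests.post(
    "https://beta.pokeapi.co/graphql/v1beta",  # assumed endpoint URL
    json={"query": query},
    timeout=60,
)
data = resp.json()["data"]["pokemon_v2_pokemon"]

with open("pokemon_raw.json", "w") as f:
    json.dump(data, f)
</code></pre></div>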
<p>As an example, let&rsquo;s start with the franchise mascot, <a href="https://bulbapedia.bulbagarden.net/wiki/Pikachu_%28Pok%C3%A9mon%29">Pikachu</a>.</p>
<figure class="align-center ">

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/025_hu_94d9a33563b85a33.webp 320w,/2024/06/pokemon-embeddings/025.png 475w" src="025.png#center"
         alt="The iconic Pokémon #25. via Nintendo" width="300" height="300"/> <figcaption>
            <p>The iconic Pokémon #25. <a href="https://www.pokemon.com/us/pokedex/pikachu">via Nintendo</a></p>
        </figcaption>
</figure>

<p>Here&rsquo;s a subset of Pikachu&rsquo;s <a href="https://github.com/minimaxir/pokemon-embeddings/blob/main/pikachu_example_raw.json">JSON metadata</a> from that query:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;id&#34;</span><span class="p">:</span> <span class="mi">25</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;name&#34;</span><span class="p">:</span> <span class="s2">&#34;pikachu&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;height&#34;</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;weight&#34;</span><span class="p">:</span> <span class="mi">60</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;base_experience&#34;</span><span class="p">:</span> <span class="mi">112</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;pokemon_v2_pokemontypes&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;pokemon_v2_type&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;name&#34;</span><span class="p">:</span> <span class="s2">&#34;electric&#34;</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">],</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;pokemon_v2_pokemonstats&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;pokemon_v2_stat&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;name&#34;</span><span class="p">:</span> <span class="s2">&#34;hp&#34;</span>
</span></span><span class="line"><span class="cl">      <span class="p">},</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;base_stat&#34;</span><span class="p">:</span> <span class="mi">35</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="err">...</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">  <span class="s2">&#34;pokemon_v2_pokemonspecy&#34;</span><span class="err">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;base_happiness&#34;</span><span class="p">:</span> <span class="mi">50</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;capture_rate&#34;</span><span class="p">:</span> <span class="mi">190</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;forms_switchable&#34;</span><span class="p">:</span> <span class="kc">false</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;gender_rate&#34;</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;has_gender_differences&#34;</span><span class="p">:</span> <span class="kc">true</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;hatch_counter&#34;</span><span class="p">:</span> <span class="mi">10</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;is_baby&#34;</span><span class="p">:</span> <span class="kc">false</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;is_legendary&#34;</span><span class="p">:</span> <span class="kc">false</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;is_mythical&#34;</span><span class="p">:</span> <span class="kc">false</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;pokemon_v2_pokemonspeciesflavortexts&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">      <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;pokemon_v2_version&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">          <span class="nt">&#34;name&#34;</span><span class="p">:</span> <span class="s2">&#34;red&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="p">},</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;flavor_text&#34;</span><span class="p">:</span> <span class="s2">&#34;When several of\nthese POK\u00e9MON\ngather, their\felectricity could\nbuild and cause\nlightning storms.&#34;</span>
</span></span><span class="line"><span class="cl">      <span class="p">},</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="err">...</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">  <span class="s2">&#34;pokemon_v2_pokemonmoves&#34;</span><span class="err">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">      <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;pokemon_v2_move&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">          <span class="nt">&#34;name&#34;</span><span class="p">:</span> <span class="s2">&#34;mega-punch&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">          <span class="nt">&#34;pokemon_v2_type&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="nt">&#34;name&#34;</span><span class="p">:</span> <span class="s2">&#34;normal&#34;</span>
</span></span><span class="line"><span class="cl">          <span class="p">}</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">      <span class="p">},</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="err">...</span>
</span></span></code></pre></div><p>There&rsquo;s definitely no shortage of Pikachu data! Some of the formatting is redundant though: most of the JSON keys have a <code>pokemon_v2_</code> string that conveys no additional semantic information, and we can minify the JSON to remove all the whitespace. We won&rsquo;t experiment with more rigorous preprocessing: after all, I only need to optimize an ETL workflow if it <em>doesn&rsquo;t</em> work, right?</p>
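<p>Those two cleanups are a few lines of Python; a minimal sketch, using the example file linked above:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import json

def strip_prefix(obj):
    """Recursively remove the redundant 'pokemon_v2_' prefix from dict keys."""
    if isinstance(obj, dict):
        return {k.replace("pokemon_v2_", ""): strip_prefix(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [strip_prefix(v) for v in obj]
    return obj

with open("pikachu_example_raw.json") as f:
    raw = json.load(f)

# separators=(",", ":") minifies by dropping all whitespace
minified = json.dumps(strip_prefix(raw), separators=(",", ":"))
</code></pre></div>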
<p>Since JSON data is so prevalent across the internet, it&rsquo;s extremely likely that a newly trained LLM will be sensitive to its schema and be able to understand it better. However, JSON is a token-inefficient encoding format, made even worse in this case by the particular choice of tokenizer. Here&rsquo;s the distribution of the encoded texts after the optimizations above, using <code>nomic-embed-text-v1.5</code>&rsquo;s text tokenizer, which is incidentally the same <a href="https://huggingface.co/google-bert/bert-base-uncased">bert-base-uncased</a> tokenizer used for BERT back in 2018:</p>
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/token_counts_hu_400e5e984e326eb1.webp 320w,/2024/06/pokemon-embeddings/token_counts_hu_cf5bcc5547d45eb.webp 768w,/2024/06/pokemon-embeddings/token_counts_hu_e4d28e56e2dc7bc9.webp 1024w,/2024/06/pokemon-embeddings/token_counts.png 1200w" src="token_counts.png"/> 
</figure>

<p>The 8,192 context length of <code>nomic-embed-text-v1.5</code> is perfect for fitting almost all Pokémon! But the median token count is 3,781 tokens, which is still somewhat high. The reason is the tokenizer: bert-base-uncased is a <a href="https://huggingface.co/learn/nlp-course/en/chapter6/6">WordPiece</a> tokenizer which is optimized for words and their common prefixes and suffixes, while JSON data is highly structured. If you use a more modern tokenizer which utilizes <a href="https://en.wikipedia.org/wiki/Byte_pair_encoding">byte pair encoding</a> (BPE), such as the <code>o200k_base</code> tokenizer which powers OpenAI&rsquo;s <a href="https://openai.com/index/hello-gpt-4o/">GPT-4o</a>, then the median token count is 2,010 tokens: nearly half the size, and therefore much faster to embed.</p>
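<p>Counting tokens with both tokenizers is straightforward; a minimal sketch, reusing the <code>minified</code> JSON string from the preprocessing sketch above:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import tiktoken
from transformers import AutoTokenizer

wordpiece = AutoTokenizer.from_pretrained("bert-base-uncased")
bpe = tiktoken.get_encoding("o200k_base")

# WordPiece vs. BPE token counts for the same minified Pokémon JSON
print(len(wordpiece.encode(minified)))
print(len(bpe.encode(minified)))
</code></pre></div>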
<p>After that, I <a href="https://github.com/minimaxir/pokemon-embeddings/blob/main/pokemon_embeddings.ipynb">encoded all the Pokémon metadata</a> into a 768D text embedding for each and every Pokémon, including unit normalization. Due to the quadratic scaling at high input token counts, this is still very computationally intensive despite the optimization tricks: for the 1,302 embeddings, it took about a half-hour on a <a href="https://colab.research.google.com">Google Colab</a> T4 GPU. The embeddings are then saved to disk in <a href="https://parquet.apache.org">parquet</a> format, a tabular format which supports nesting sequences of floats natively (don&rsquo;t use a CSV to store embeddings!). The embedding generation is the hard part; now it&rsquo;s time for the fun part!</p>
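<p>A minimal sketch of that bulk encoding run with Sentence Transformers, where <code>pokemon_docs</code> and <code>pokemon_names</code> are hypothetical lists of the minified JSON strings and Pokémon names, and the <code>search_document</code> task prefix comes from Nomic&rsquo;s documentation (more on prefixes later):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import pandas as pd
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

# Nomic's docs recommend a task prefix for documents being indexed
docs = ["search_document: " + d for d in pokemon_docs]
embeddings = model.encode(docs, normalize_embeddings=True, show_progress_bar=True)

# parquet supports nested float sequences natively, unlike CSV
df = pd.DataFrame({"name": pokemon_names, "embedding": list(embeddings)})
df.to_parquet("pokemon_embeddings.parquet")
</code></pre></div>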
<p>Let&rsquo;s start off with Pikachu. What Pokémon is Pikachu most similar to, i.e. has the highest cosine similarity? Remember, since all the embeddings are normalized, we can get all the cosine similarities by matrix multiplying the Pikachu embedding against all the other embeddings. Let&rsquo;s include the top 3 of each of Pokémon&rsquo;s nine (!) generations to date:</p>
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/25_similar_text_hu_d3abb73fceff08ae.webp 320w,/2024/06/pokemon-embeddings/25_similar_text_hu_7f6bea07a378c8ca.webp 768w,/2024/06/pokemon-embeddings/25_similar_text_hu_13958884c73294ad.webp 1024w,/2024/06/pokemon-embeddings/25_similar_text.png 1500w" src="25_similar_text.png"/> 
</figure>

<p>These results are better than I expected! Each generation has a &ldquo;<a href="https://bulbapedia.bulbagarden.net/wiki/Electric_rodents">Pikaclone</a>&rdquo;: a weak Electric-type rodent Pokémon, and this similarity calculation found most of them. I&rsquo;m not sure what <a href="https://bulbapedia.bulbagarden.net/wiki/Phantump_%28Pok%C3%A9mon%29">Phantump</a> and <a href="https://bulbapedia.bulbagarden.net/wiki/Trevenant_%28Pok%C3%A9mon%29">Trevenant</a> are doing under Gen VI though: they&rsquo;re Ghost/Grass Pokémon.</p>
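<p>For reference, the retrieval step behind charts like the one above is only a few lines of numpy; a minimal sketch, assuming the dataframe from the encoding sketch earlier with rows in Pokédex order:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np

embs = np.stack(df["embedding"].to_numpy())  # (1302, 768), already unit-normalized
pikachu = embs[24]                           # row 24 = Pokédex #25, assuming ID order

sims = embs @ pikachu                        # cosine similarities against every Pokémon
top = np.argsort(sims)[::-1][1:4]            # index 0 is Pikachu itself (similarity 1.0)
print(df["name"].iloc[top].tolist())
</code></pre></div>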
<p>Here&rsquo;s a few more interesting Pokémon comparisons:</p>
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/157_similar_text_hu_3d346fcf9518b458.webp 320w,/2024/06/pokemon-embeddings/157_similar_text_hu_4ee5feef47d7753c.webp 768w,/2024/06/pokemon-embeddings/157_similar_text_hu_2922bbc71f9c3c31.webp 1024w,/2024/06/pokemon-embeddings/157_similar_text.png 1500w" src="157_similar_text.png"
         alt="Typhlosion is the final evolution of the Gen II Fire starter Pokémon: it has a high similarity with atleast one of every generation&rsquo;s Fire starter Pokémon lineages."/> <figcaption>
            <p><a href="https://bulbapedia.bulbagarden.net/wiki/Typhlosion_%28Pok%C3%A9mon%29">Typhlosion</a> is the final evolution of the Gen II Fire starter Pokémon: it has a high similarity with atleast one of every generation&rsquo;s Fire starter Pokémon lineages.</p>
        </figcaption>
</figure>

<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/144_similar_text_hu_f33e01b9e6fd18de.webp 320w,/2024/06/pokemon-embeddings/144_similar_text_hu_6901b401ed8d948b.webp 768w,/2024/06/pokemon-embeddings/144_similar_text_hu_c3d530c15d378e14.webp 1024w,/2024/06/pokemon-embeddings/144_similar_text.png 1500w" src="144_similar_text.png"
         alt="Articuno, a Legendary Ice/Flying Pokémon, has high similarity with Legendary, Ice, and Flying Pokémon, plus all combinations therein."/> <figcaption>
            <p><a href="https://bulbapedia.bulbagarden.net/wiki/Articuno_%28Pok%C3%A9mon%29">Articuno</a>, a Legendary Ice/Flying Pokémon, has high similarity with Legendary, Ice, and Flying Pokémon, plus all combinations therein.</p>
        </figcaption>
</figure>

<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/151_similar_text_hu_d787c257f6d1c1fc.webp 320w,/2024/06/pokemon-embeddings/151_similar_text_hu_14f77fe5dcb615a2.webp 768w,/2024/06/pokemon-embeddings/151_similar_text_hu_ee9cf4523b03c9ca.webp 1024w,/2024/06/pokemon-embeddings/151_similar_text.png 1500w" src="151_similar_text.png"
         alt="Mew, the infamous legendary from the original games has the gimmick of being able to learn every move, has the most amount of metadata by far: appropriately it has poor similarity with others, although similarity with Arceus from Gen IV, the Pokémon equivalent of God with a similar gimmick."/> <figcaption>
            <p><a href="https://bulbapedia.bulbagarden.net/wiki/Mew_%28Pok%C3%A9mon%29">Mew</a>, the infamous legendary from the original games has the gimmick of being able to learn every move, has the most amount of metadata by far: appropriately it has poor similarity with others, although similarity with <a href="https://bulbapedia.bulbagarden.net/wiki/Arceus_%28Pok%C3%A9mon%29">Arceus</a> from Gen IV, the Pokémon equivalent of God with a similar gimmick.</p>
        </figcaption>
</figure>

<p>You may have noticed the numerical cosine similarity of all these Pokémon is very high: if a similarity of 1 indicates an identical match, does a high value imply that a Pokémon is super similar? It&rsquo;s likely that the similarities are high because the input is all in the same JSON formatting, whereas the core <code>nomic-embed-text-v1.5</code> model was trained on a variety of text styles. Another potential cause is a &ldquo;cheat&rdquo; I did for simplicity: the <code>nomic-embed-text-v1.5</code> documentation says that a <code>search_document</code> prefix is required for encoding the base input documents and a <code>search_query</code> prefix is required for the comparison vector; in my testing it doesn&rsquo;t affect the similarity much, if at all. In practice, the absolute value of cosine similarity doesn&rsquo;t matter if you&rsquo;re just selecting the objects with the highest similarity anyways.</p>
<p>What if we just plot <em>every possible combination</em> of Pokémon cosine similarities? With 1,000+ Pokémon, that&rsquo;s over 1 million combinations. Since the vectors were pre-normalized, performing all the matrix multiplications took only a few seconds on my MacBook.</p>
<p>Here&rsquo;s the result of plotting 1 million points on a single chart!</p>
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/pokemon_cossim_text_hu_4e946cbcf5592ceb.webp 320w,/2024/06/pokemon-embeddings/pokemon_cossim_text_hu_162006d7e7cb517a.webp 768w,/2024/06/pokemon-embeddings/pokemon_cossim_text_hu_adde71533c7b8122.webp 1024w,/2024/06/pokemon-embeddings/pokemon_cossim_text.png 1200w" src="pokemon_cossim_text.png"/> 
</figure>

<p>Although it looks more like a quilt, a few things jump out. One curious case is the &ldquo;square&rdquo; of lighter Gen VIII and Gen IX in the upper-right corner: it appears those two generations have lower similarity with the others, and the similarity worsens the further back you go toward Gen I. Those two generations are the Nintendo Switch games (Sword/Shield/Scarlet/Violet), for which PokéAPI explicitly notes it has worse data. Also, there are rows of a low-similarity blue, such as one just before Gen II: who&rsquo;s that Pokémon? Quickly checking the Pokémon with the lowest median similarity by generation:</p>
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/pokemon_dissimilar_text_hu_6a8915531cedabd3.webp 320w,/2024/06/pokemon-embeddings/pokemon_dissimilar_text_hu_ed6bebce86bed918.webp 768w,/2024/06/pokemon-embeddings/pokemon_dissimilar_text_hu_539044188fdd35ba.webp 1024w,/2024/06/pokemon-embeddings/pokemon_dissimilar_text.png 1500w" src="pokemon_dissimilar_text.png"/> 
</figure>

<p>The mystery Pokémon is <a href="https://bulbapedia.bulbagarden.net/wiki/Magikarp_%28Pok%C3%A9mon%29">Magikarp</a>, unsurprisingly, with its <em>extremely</em> limited movepool. Most of these Pokémon have forced gimmick movesets, especially <a href="https://bulbapedia.bulbagarden.net/wiki/Unown_%28Pok%C3%A9mon%29">Unown</a>, <a href="https://bulbapedia.bulbagarden.net/wiki/Smeargle_%28Pok%C3%A9mon%29">Smeargle</a>, and <a href="https://bulbapedia.bulbagarden.net/wiki/Wobbuffet_%28Pok%C3%A9mon%29">Wobbuffet</a>, so it makes sense the metadata treats them as dissimilar to most others. Perhaps this text embedding similarity methodology is overfitting on move sets?</p>
<p>Overall, there&rsquo;s definitely some signal with these text embeddings. How else can we identify interesting Pokémon relationships?</p>
<h2 id="pokémon-snap">Pokémon Snap</h2>
<p>We&rsquo;ve only been working with text embeddings, but what about other types of embeddings, such as image embeddings? Image embeddings using <a href="https://en.wikipedia.org/wiki/Vision_transformer">vision transformer</a> models are generated roughly the same way as the text embeddings above, by manipulating the last hidden state and optionally normalizing it. The inputs to the model are square image patches encoded as &ldquo;tokens&rdquo;: only a few hundred processed patches are ever used as inputs, so generating image embeddings is much faster than generating the text embeddings.</p>
<p>A couple years ago I hacked together a Python package named <a href="https://github.com/minimaxir/imgbeddings/tree/main">imgbeddings</a> which uses OpenAI&rsquo;s <a href="https://openai.com/index/clip/">CLIP</a> to generate the embeddings, albeit with <a href="https://x.com/minimaxir/status/1507166313281585164">mixed results</a>. Recently, Nomic also released a new model, <a href="https://huggingface.co/nomic-ai/nomic-embed-vision-v1.5">nomic-embed-vision-v1.5</a>, which generates image embeddings with better benchmark performance than CLIP. What&rsquo;s notable about these embeddings is that they are aligned with the ones from <code>nomic-embed-text-v1.5</code>, which allows matching text similarity with images or <em>vice versa</em> and enables <a href="https://cloud.google.com/use-cases/multimodal-ai?hl=en">multimodal applications</a>.</p>
<p>But for now, can we see whether image embeddings derived from Pokémon images exhibit similar similarity behavior? PokéAPI fortunately has the official artwork for each Pokémon, so I <a href="https://github.com/minimaxir/pokemon-embeddings/blob/main/pokemon_images.ipynb">downloaded them</a>, composited them onto a white background, and resized them all to 224x224 for apples-to-apples comparisons.</p>
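<p>That image preparation is a few lines with Pillow; a minimal sketch, where the filename is Pikachu&rsquo;s official artwork from PokéAPI:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">from PIL import Image

def preprocess(path):
    """Flatten transparent official art onto white and resize to 224x224."""
    art = Image.open(path).convert("RGBA")
    background = Image.new("RGBA", art.size, (255, 255, 255, 255))
    flattened = Image.alpha_composite(background, art).convert("RGB")
    return flattened.resize((224, 224))

img = preprocess("025.png")
</code></pre></div>
<p>We expect a high cosine similarity since, as with the text embeddings, the &ldquo;style&rdquo; of all the images is the same. Let&rsquo;s plot the similarities of all Pokémon, by their images only.</p>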
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/pokemon_cossim_image_hu_1c224446fd2d33ac.webp 320w,/2024/06/pokemon-embeddings/pokemon_cossim_image_hu_4f63c7e6b643d270.webp 768w,/2024/06/pokemon-embeddings/pokemon_cossim_image_hu_69bf5e8b39390f66.webp 1024w,/2024/06/pokemon-embeddings/pokemon_cossim_image.png 1200w" src="pokemon_cossim_image.png"/> 
</figure>

<p>Unfortunately, no patterns jump out this time. All the image similarity values are even higher than the text similarity values, although that&rsquo;s not a big deal since we are looking at the most similar matches. How does Pikachu&rsquo;s famous official artwork compare with other Pokémon?</p>
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/25_similar_image_hu_38ac0e401d65ee2c.webp 320w,/2024/06/pokemon-embeddings/25_similar_image_hu_7046443de47b4ee1.webp 768w,/2024/06/pokemon-embeddings/25_similar_image_hu_ffe3902d02dae773.webp 1024w,/2024/06/pokemon-embeddings/25_similar_image.png 1500w" src="25_similar_image.png"/> 
</figure>

<p>Pikachu&rsquo;s most similar Pokémon by image aren&rsquo;t just mouse Pokémon as I thought they would be; instead the pattern is less clear, appearing to favor mostly Pokémon with four limbs. (Pikachu&rsquo;s image does have a strong similarity with Gen VII&rsquo;s <a href="https://bulbapedia.bulbagarden.net/wiki/Mimikyu_%28Pok%C3%A9mon%29">Mimikyu</a>&rsquo;s image, which is hilarious since that particular Pokémon&rsquo;s gimmick is intentionally trying to look like Pikachu.)</p>
<p>After testing a few more Pokémon, it turns out that this image embedding model does respond to visual primitives, which has its uses.</p>
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/18_similar_image_hu_1612a7ba17e5ae91.webp 320w,/2024/06/pokemon-embeddings/18_similar_image_hu_67025db2f5aeaf75.webp 768w,/2024/06/pokemon-embeddings/18_similar_image_hu_4805d93d4c98cabf.webp 1024w,/2024/06/pokemon-embeddings/18_similar_image.png 1500w" src="18_similar_image.png"
         alt="Pidgeot is a bird, and it matches all other birds. Birds would definitely be in an image training dataset."/> <figcaption>
            <p><a href="https://bulbapedia.bulbagarden.net/wiki/Pidgeot_%28Pok%C3%A9mon%29">Pidgeot</a> is a bird, and it matches all other birds. Birds would definitely be in an image training dataset.</p>
        </figcaption>
</figure>

<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/101_similar_image_hu_c171b79453ea948b.webp 320w,/2024/06/pokemon-embeddings/101_similar_image_hu_50e3b83533c12d84.webp 768w,/2024/06/pokemon-embeddings/101_similar_image_hu_f464f67fcf469df6.webp 1024w,/2024/06/pokemon-embeddings/101_similar_image.png 1500w" src="101_similar_image.png"
         alt="Electrode is a ball, and the embeddings found similarly rotund Pokémon."/> <figcaption>
            <p><a href="https://bulbapedia.bulbagarden.net/wiki/Electrode_%28Pok%C3%A9mon%29">Electrode</a> is a ball, and the embeddings found similarly rotund Pokémon.</p>
        </figcaption>
</figure>

<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/230_similar_image_hu_2f5b44f49ac99a85.webp 320w,/2024/06/pokemon-embeddings/230_similar_image_hu_bf0bb9446330d79e.webp 768w,/2024/06/pokemon-embeddings/230_similar_image_hu_77535585356eec73.webp 1024w,/2024/06/pokemon-embeddings/230_similar_image.png 1500w" src="230_similar_image.png"
         alt="Kingdra apparently is similar to other blue Pokémon."/> <figcaption>
            <p><a href="https://bulbapedia.bulbagarden.net/wiki/Kingdra_%28Pok%C3%A9mon%29">Kingdra</a> apparently is similar to other blue Pokémon.</p>
        </figcaption>
</figure>

<p>Both text and image embedding approaches have their own style. But are there ways to combine them?</p>
<h2 id="chat-with-your-pokédex">Chat With Your Pokédex</h2>
<p>Earlier I alluded to aligning text and image embeddings in a more multimodal manner. Since <code>nomic-embed-vision-v1.5</code> was conditioned on <code>nomic-embed-text-v1.5</code> outputs, you are able to compute the cosine similarities between the image embeddings and text embeddings! However, it&rsquo;s not as robust: the cosine similarities between objects of the two modes tend to be very low, at about 0.10 in the best-case scenario. Again, if all we&rsquo;re looking at is the highest similarity, then that&rsquo;s fine.</p>
<p>The most common use case for multimodal reasoning is asking questions (to be converted to a text embedding) and comparing it with a set of image embeddings. Let&rsquo;s try it with Pokémon by <a href="https://github.com/minimaxir/pokemon-embeddings/blob/main/pokemon_multimodal_qa.ipynb">asking it a leading question</a> for testing: what looks like an ice cream cone?</p>
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/qa_1_hu_8e3221047d889ca2.webp 320w,/2024/06/pokemon-embeddings/qa_1_hu_533a6f2fd94474e4.webp 768w,/2024/06/pokemon-embeddings/qa_1_hu_d0d5978a89a664ee.webp 1024w,/2024/06/pokemon-embeddings/qa_1.png 1050w" src="qa_1.png"/> 
</figure>

<p>Surprisingly, it got the result correct with <a href="https://bulbapedia.bulbagarden.net/wiki/Vanillish_%28Pok%C3%A9mon%29">Vanillish</a>, along with other &ldquo;cream&rdquo; and &ldquo;ice&rdquo; Pokémon. Not sure why <a href="https://bulbapedia.bulbagarden.net/wiki/Metapod_%28Pok%C3%A9mon%29">Metapod</a> is there, though.</p>
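<p>A minimal sketch of that cross-modal lookup, where <code>image_embs</code> is a hypothetical (1302, 768) matrix of normalized image embeddings from the step above:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np
from sentence_transformers import SentenceTransformer

text_model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

# questions use the search_query task prefix per Nomic's docs
question = "search_query: What looks like an ice cream cone?"
q_emb = text_model.encode(question, normalize_embeddings=True)

# cross-modal cosine similarities: low in absolute terms, but still rankable
sims = image_embs @ q_emb
best = np.argsort(sims)[::-1][:3]
</code></pre></div>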
<p>A few more Qs and As:</p>
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/qa_2_hu_3233384b59e2a4e2.webp 320w,/2024/06/pokemon-embeddings/qa_2_hu_c6cfa2f4a81738b6.webp 768w,/2024/06/pokemon-embeddings/qa_2_hu_bbbc1fa0f0144e15.webp 1024w,/2024/06/pokemon-embeddings/qa_2.png 1050w" src="qa_2.png"
         alt="The model did identify some cats, but only Torracat is orange."/> <figcaption>
            <p>The model did identify some cats, but only <a href="https://bulbapedia.bulbagarden.net/wiki/Torracat_%28Pok%C3%A9mon%29">Torracat</a> is orange.</p>
        </figcaption>
</figure>

<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/qa_3_hu_42fd3039bd4f26b4.webp 320w,/2024/06/pokemon-embeddings/qa_3_hu_bd8dc72563061a30.webp 768w,/2024/06/pokemon-embeddings/qa_3_hu_8e404e3e0bd8e8d0.webp 1024w,/2024/06/pokemon-embeddings/qa_3.png 1050w" src="qa_3.png"
         alt="Unown definitely fits the bill with a very prominent one-eye and higher similarity."/> <figcaption>
            <p><a href="https://bulbapedia.bulbagarden.net/wiki/Unown_%28Pok%C3%A9mon%29">Unown</a> definitely fits the bill with a very prominent one-eye and higher similarity.</p>
        </figcaption>
</figure>

<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/qa_4_hu_5b0b318502699c9c.webp 320w,/2024/06/pokemon-embeddings/qa_4_hu_822df483a3e6842a.webp 768w,/2024/06/pokemon-embeddings/qa_4_hu_822dde97c1b28626.webp 1024w,/2024/06/pokemon-embeddings/qa_4.png 1050w" src="qa_4.png"
         alt="A Pokémon with the name &ldquo;Cutiefly&rdquo; being the most similar to the question is a funny coincidence."/> <figcaption>
            <p>A Pokémon with the name &ldquo;<a href="https://bulbapedia.bulbagarden.net/wiki/Cutiefly_%28Pok%C3%A9mon%29">Cutiefly</a>&rdquo; being the most similar to the question is a funny coincidence.</p>
        </figcaption>
</figure>

<p>The relationship between text and Pokémon images with these models is not perfect, but it&rsquo;s honestly much better than I expected!</p>
<h2 id="2da-master">2D.A Master</h2>
<p>Lastly, there are many ways to find signal among the high-dimensional noise, and doing so may resolve some of the counterintuitive relationships we saw earlier. One common method is dimensionality reduction, which shrinks the embedding: a popular target size is 2D for easy data visualization, and I am definitely in favor of data visualization! The classical statistical approach is <a href="https://en.wikipedia.org/wiki/Principal_component_analysis">principal component analysis</a> (PCA), which identifies the most &ldquo;important&rdquo; aspects of a matrix, but a more modern approach is <a href="https://umap-learn.readthedocs.io/en/latest/">uniform manifold approximation &amp; projection</a> (UMAP), which trains a projection that accounts for how data points relate to all other data points to <a href="https://umap-learn.readthedocs.io/en/latest/how_umap_works.html">find its underlying structure</a>. In theory, the reduction should allow the embeddings to generalize better.</p>
<p>For the Pokémon embeddings, we can take the opportunity to allow the model to account for both the text and image embeddings, and their potential interactions therein. Therefore, I concatenated the text and image embeddings for each Pokémon (a 1536D embedding total), and trained a UMAP to project it down to 2D. Now we can visualize it!</p>
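<p>A minimal sketch of that projection with the umap-learn package, where <code>text_embs</code> and <code>image_embs</code> are the hypothetical (1302, 768) matrices from the earlier steps:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np
import umap  # from the umap-learn package

combined = np.hstack([text_embs, image_embs])  # (1302, 1536)

reducer = umap.UMAP(n_components=2, random_state=42)
coords = reducer.fit_transform(combined)       # (1302, 2), ready to plot
</code></pre></div>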
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/pokemon_umap_hu_5c319611f462c24.webp 320w,/2024/06/pokemon-embeddings/pokemon_umap_hu_2a8289708e6aeb22.webp 768w,/2024/06/pokemon-embeddings/pokemon_umap_hu_67d3c2c6bda1487c.webp 1024w,/2024/06/pokemon-embeddings/pokemon_umap.png 1200w" src="pokemon_umap.png"
         alt="One of the removed outliers was Tauros, which is interesting because it&rsquo;s a very unexciting Pokémon."/> <figcaption>
            <p>One of the removed outliers was <a href="https://bulbapedia.bulbagarden.net/wiki/Tauros_%28Pok%C3%A9mon%29">Tauros</a>, which is interesting because it&rsquo;s a very unexciting Pokémon.</p>
        </figcaption>
</figure>

<p>Unfortunately, plotting each Pokémon&rsquo;s image onto a single chart would be too cluttered to read, but from this chart we can see that instead of organizing by Pokémon type like my 2016 approach did, this approach organizes much more by generation: the earlier generations vs. the later generations. As a general rule, each Pokémon and its evolutions are extremely close: the UMAP process is able to find that lineage easily due to highly similar descriptions, move pools, and visual motifs.</p>
<p>As with the cosine similarities, we can now find the most similar Pokémon, this time seeing which points have the <strong>lowest</strong> <a href="https://en.wikipedia.org/wiki/Euclidean_distance">Euclidean distance</a> (0.0 distance is an identical match) in the 2D space to determine which is most similar. How does Pikachu fare now?</p>
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/25_similar_umap_hu_c246697f3bd129ed.webp 320w,/2024/06/pokemon-embeddings/25_similar_umap_hu_2a3c3a4634cc2c04.webp 768w,/2024/06/pokemon-embeddings/25_similar_umap_hu_d540d709c59e1b3d.webp 1024w,/2024/06/pokemon-embeddings/25_similar_umap.png 1500w" src="25_similar_umap.png"/> 
</figure>

<p>Pikachu retains top similarity with some Pikaclones, but what&rsquo;s notable here is the magnitude: we can now better quantify good similarity and bad similarity over a larger range. In this case, many of the Pokémon at distance &gt;1.0 clearly do not resemble an Electric rodent.</p>
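<p>The distance computation itself is a one-liner; a minimal sketch, reusing the <code>coords</code> matrix from the UMAP sketch above:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np

# Euclidean distances from every point to Pikachu's point (row 24)
dists = np.linalg.norm(coords - coords[24], axis=1)
nearest = np.argsort(dists)[1:4]  # index 0 is Pikachu itself at distance 0.0
</code></pre></div>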
<p>How about some other Pokémon?</p>
<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/129_similar_umap_hu_938f7320031c2ed5.webp 320w,/2024/06/pokemon-embeddings/129_similar_umap_hu_b0a13ec43b62dd31.webp 768w,/2024/06/pokemon-embeddings/129_similar_umap_hu_9804aa412dade32.webp 1024w,/2024/06/pokemon-embeddings/129_similar_umap.png 1500w" src="129_similar_umap.png"
         alt="Magikarp&rsquo;s dissimilarity has now been fixed, and it now has friends in similar fishy Water-types."/> <figcaption>
            <p><a href="https://bulbapedia.bulbagarden.net/wiki/Magikarp_%28Pok%C3%A9mon%29">Magikarp</a>&rsquo;s dissimilarity has now been fixed, and it now has friends in similar fishy Water-types.</p>
        </figcaption>
</figure>

<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/122_similar_umap_hu_a02cdca1de8e208b.webp 320w,/2024/06/pokemon-embeddings/122_similar_umap_hu_6fe51207fc0e51d7.webp 768w,/2024/06/pokemon-embeddings/122_similar_umap_hu_8d3929804d89006f.webp 1024w,/2024/06/pokemon-embeddings/122_similar_umap.png 1500w" src="122_similar_umap.png"
         alt="Mr. Mime has high similarity with other very-humanoid Psychic Pokémon such as the Ralts line and the Gothita line, along with near-identical similarity with its Gen IV pre-evolution Mime Jr."/> <figcaption>
            <p><a href="https://bulbapedia.bulbagarden.net/wiki/Mr._Mime_%28Pok%C3%A9mon%29">Mr. Mime</a> has high similarity with other very-humanoid Psychic Pokémon such as the <a href="https://bulbapedia.bulbagarden.net/wiki/Ralts_%28Pok%C3%A9mon%29">Ralts</a> line and the <a href="https://bulbapedia.bulbagarden.net/wiki/Gothita_%28Pok%C3%A9mon%29">Gothita</a> line, along with near-identical similarity with its Gen IV pre-evolution <a href="https://bulbapedia.bulbagarden.net/wiki/Mime_Jr._%28Pok%C3%A9mon%29">Mime Jr</a>.</p>
        </figcaption>
</figure>

<figure>

    <img loading="lazy" srcset="/2024/06/pokemon-embeddings/12_similar_umap_hu_d79cec0bbd2548dc.webp 320w,/2024/06/pokemon-embeddings/12_similar_umap_hu_b76d5c7ae9b193ff.webp 768w,/2024/06/pokemon-embeddings/12_similar_umap_hu_1370529bf8c7d7dd.webp 1024w,/2024/06/pokemon-embeddings/12_similar_umap.png 1500w" src="12_similar_umap.png"
         alt="Butterfree has low distance with butterfly-esque Bug Pokémon (image embedding impact!) and higher distance with other type of Bugs."/> <figcaption>
            <p><a href="https://bulbapedia.bulbagarden.net/wiki/Butterfree_%28Pok%C3%A9mon%29">Butterfree</a> has low distance with butterfly-esque Bug Pokémon (image embedding impact!) and higher distance with other type of Bugs.</p>
        </figcaption>
</figure>

<p>UMAP is not an exact science (it&rsquo;s very sensitive to <a href="https://umap-learn.readthedocs.io/en/latest/parameters.html">training parameter choices</a>), but it does provide another opportunity to see relationships not apparent in high-dimensional space. The low similarities with Gen VIII and Gen IX are concerning: I suspect the UMAP fitting process amplified whatever issue is present with the data for those generations.</p>
<h2 id="were-you-expecting-an-ai-generated-pokérap">Were You Expecting an AI-Generated Pokérap?</h2>
<p>In all, this was a successful exploration of Pokémon data: even though it&rsquo;s not perfect, the failures are also interesting. Embeddings encourage engineers to go full <a href="https://www.dictionary.com/browse/yolo">YOLO</a> because it&rsquo;s actually rewarding to do so! Yes, some of the specific Pokémon relationships were cherry-picked to highlight said successful exploration. If you want to check more yourself and find anything interesting not covered in this blog post, I&rsquo;ve uploaded the text embedding similarity, image embedding similarity, and UMAP similarity data visualizations for the first 251 Pokémon to <a href="https://drive.google.com/drive/folders/132e-OXucJUqh-0YmqkjKVhXiBqocYTV5?usp=sharing">this public Google Drive folder</a>.</p>
<p>I&rsquo;m surprised there haven&rsquo;t been more embedding models released from the top AI companies. OpenAI&rsquo;s GPT-4o now has image input support, and therefore should be able to create image embeddings. <a href="https://www.anthropic.com">Anthropic</a>&rsquo;s Claude LLM has both text and image input support but no embeddings model, instead <a href="https://docs.anthropic.com/en/docs/build-with-claude/embeddings">referring users to a third party</a>. One of the more interesting embedding model releases from a major player was from Google and went completely under the radar: it&rsquo;s a <a href="https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-multimodal-embeddings">multimodal embedding model</a> which can take text, images, and video input <em>simultaneously</em> and generate a 1408D embedding that&rsquo;s theoretically more robust than just concatenating a text embedding and an image embedding.</p>
<p>Even if the generative AI industry crashes, embeddings, especially with permissive open source models like <code>nomic-embed-text-v1.5</code>, will continue to thrive and be useful. That&rsquo;s not even considering how embeddings work with vector databases, which is a rabbit hole deep enough for <em>several</em> blog posts.</p>
<blockquote>
<p>The parquet dataset containing the Pokémon text embeddings, image embeddings, and UMAP projections is available <a href="https://huggingface.co/datasets/minimaxir/pokemon-embeddings">on Hugging Face</a>.</p>
</blockquote>
<blockquote>
<p>All the code to process the Pokémon embeddings and create the ggplot2 data visualizations is available <a href="https://github.com/minimaxir/pokemon-embeddings">in this GitHub repository</a>.</p>
</blockquote>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p>The 128-multiple dimensionality of recent embedding models is not a coincidence: modern <a href="https://www.nvidia.com/en-us/">NVIDIA</a> GPUs used to train LLMs get a training speed boost for model parameters with a dimensionality <a href="https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#cublas-tile-dim">that&rsquo;s a multiple of 128</a>.&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:2">
<p>You can do unit vector normalization in Sentence Transformers by passing <code>normalize_embeddings=True</code> to <code>model.encode()</code>.&#160;<a href="#fnref:2" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>Pushing ChatGPT&#39;s Structured Data Support To Its Limits</title>
      <link>https://minimaxir.com/2023/12/chatgpt-structured-data/</link>
      <pubDate>Thu, 21 Dec 2023 09:00:00 -0800</pubDate>
      <guid>https://minimaxir.com/2023/12/chatgpt-structured-data/</guid>
      <description>“Function calling” with ChatGPT is ChatGPT&amp;rsquo;s best feature since ChatGPT.</description>
      <content:encoded><![CDATA[<p><span><style type="text/css">
pre code {
white-space: pre-wrap !important;
}
</style></span></p>
<p>If your only exposure to <a href="https://openai.com">OpenAI</a>&rsquo;s <a href="https://chat.openai.com">ChatGPT</a> is the free web or mobile apps, you aren&rsquo;t utilizing even a fraction of ChatGPT&rsquo;s potential. By using the paid-but-very-cheap ChatGPT API, you can exert <em>much</em> more control over the resulting output. Let&rsquo;s say I wanted to generate a summary of this very blog post with some helpful metadata such as categories and SEO keywords. I fed ChatGPT the text of this blog post along with a command to <code>Summarize this blog post into 3 distinct bullet points and a short blurb for sharing on social media, and also provide a list of 3 post categories and a list of 3 SEO keywords:</code></p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-plaintext" data-lang="plaintext"><span class="line"><span class="cl">Bullet Points:
</span></span><span class="line"><span class="cl">- Utilize the paid ChatGPT API for more control over output and to maintain a given JSON schema
</span></span><span class="line"><span class="cl">- Learn about prompt engineering and system prompts to improve ChatGPT&#39;s output quality
</span></span><span class="line"><span class="cl">- Simplify schema input/output with Pydantic for structured data support
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">Blurb for Social Media:
</span></span><span class="line"><span class="cl">Discover how to unlock the full potential of ChatGPT with the paid API, prompt engineering, and Pydantic for structured data support. Gain more control over output and learn to maintain a given JSON schema. #ChatGPT #AI #Pydantic
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">Post Categories:
</span></span><span class="line"><span class="cl">1. Artificial Intelligence
</span></span><span class="line"><span class="cl">2. API Development
</span></span><span class="line"><span class="cl">3. Data Schema
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">SEO Keywords:
</span></span><span class="line"><span class="cl">1. ChatGPT API
</span></span><span class="line"><span class="cl">2. Prompt Engineering
</span></span><span class="line"><span class="cl">3. Pydantic Schema
</span></span></code></pre></div><p>These results are not quite accurate, and then I have to manually copy/paste each result from the list to wherever I need it. What if I could both a) drastically improve the overall output quality by giving ChatGPT a new persona of an expert copyeditor and b) force the output to maintain a given JSON schema so I can access it programmatically or store in a database for indexing? Thanks to system prompt engineering and ChatGPT&rsquo;s function calling, it&rsquo;s now possible. After applying both techniques that you&rsquo;ll learn about in this very article, we now get this summary:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;summary&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;ChatGPT&#39;s potential is underutilized without the paid ChatGPT API&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;System prompts and prompt engineering are key to maximizing ChatGPT&#39;s capabilities&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Structured data support in ChatGPT allows for more control over output and input&#34;</span>
</span></span><span class="line"><span class="cl">  <span class="p">],</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;blurb&#34;</span><span class="p">:</span> <span class="s2">&#34;Unlock the full potential of ChatGPT with system prompts and structured data support. Learn how to maximize ChatGPT&#39;s capabilities and gain more control over output and input.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;categories&#34;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&#34;AI and Machine Learning&#34;</span><span class="p">,</span> <span class="s2">&#34;Technology&#34;</span><span class="p">,</span> <span class="s2">&#34;Programming&#34;</span><span class="p">],</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;keywords&#34;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&#34;ChatGPT&#34;</span><span class="p">,</span> <span class="s2">&#34;system prompts&#34;</span><span class="p">,</span> <span class="s2">&#34;structured data&#34;</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></div><p>Much better!</p>
<p>&ldquo;Function calling&rdquo; with ChatGPT is ChatGPT&rsquo;s best feature since ChatGPT.</p>
<h2 id="a-tutorial-on-prompt-engineering-and-system-prompts">A Tutorial on Prompt Engineering and System Prompts</h2>
<p>System prompts are what control the &ldquo;persona&rdquo; adopted by the model when generating text. Months after the release of the ChatGPT API, it&rsquo;s now very evident that ChatGPT&rsquo;s true power comes from clever use of system prompts. This is even more so starting with <code>gpt-3.5-turbo-0613</code>, released <a href="https://openai.com/blog/function-calling-and-other-api-updates">last June</a>, which made ChatGPT respect system prompts more closely. OpenAI has also released a <a href="https://platform.openai.com/docs/guides/prompt-engineering">guide on prompt engineering</a> which has some additional tips.</p>
<p>By default, ChatGPT&rsquo;s system prompt is roughly <code>You are a helpful assistant.</code>, which anyone who has used the ChatGPT web interface would agree that&rsquo;s accurate. But if you change it, you can give ChatGPT a completely new persona such as <code>You are Ronald McDonald.</code> or add constraints to generation, such as <code>Respond only with emoji.</code>. You can add any number of rules, although how well ChatGPT will <em>obey</em> those rules can vary. Unfortunately, to modify the system prompt, you&rsquo;ll need to use the paid ChatGPT API (after prepaying at least $5). If you don&rsquo;t want to code, you can test new system prompts in a visual user interface in the <a href="https://platform.openai.com/playground?mode=chat">ChatGPT Chat Playground</a>.</p>
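<p>A minimal sketch of setting a custom system prompt through the API, using the official <code>openai</code> Python package (v1-style client) and assuming <code>OPENAI_API_KEY</code> is set in the environment:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "system", "content": "You are Ronald McDonald. Respond only with emoji."},
        {"role": "user", "content": "How are you doing today?"},
    ],
)
print(response.choices[0].message.content)
</code></pre></div>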
<figure>

    <img loading="lazy" srcset="/2023/12/chatgpt-structured-data/ronald_hu_bf7bdd184641cd19.webp 320w,/2023/12/chatgpt-structured-data/ronald_hu_ffad8ef13bc9fa0b.webp 768w,/2023/12/chatgpt-structured-data/ronald_hu_516749cb56890e2c.webp 1024w,/2023/12/chatgpt-structured-data/ronald.webp 1262w" src="ronald.webp"/> 
</figure>

<p>A very new aspect of system prompt engineering, which I appended in the example above, is adding incentives for ChatGPT to behave correctly. Without the $500 tip incentive, ChatGPT only returns a single emoji, which is a boring response; after offering a tip, it generates the 5 emoji as requested.</p>
<figure>

    <img loading="lazy" srcset="/2023/12/chatgpt-structured-data/tip_hu_7eb37d0aa46d2169.webp 320w,/2023/12/chatgpt-structured-data/tip_hu_a760da54b0fa7ceb.webp 768w,/2023/12/chatgpt-structured-data/tip.webp 800w" src="tip.webp"
         alt="Offering a tip made GPT-4 explain more. via @voooooogel"/> <figcaption>
            <p>Offering a tip made GPT-4 explain more. <a href="https://twitter.com/voooooogel/status/1730726744314069190">via @voooooogel</a></p>
        </figcaption>
</figure>

<p>As another example, let&rsquo;s <a href="https://chat.openai.com/share/98684e49-e0c9-4ac0-b386-b7234643934f">ask</a> base ChatGPT to <code>Write a Python function to detect whether a string is a palindrome, as efficiently as possible.</code></p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-plaintext" data-lang="plaintext"><span class="line"><span class="cl">Certainly! Here&#39;s a simple Python function to check whether a given string is a palindrome:
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">```python
</span></span><span class="line"><span class="cl">def is_palindrome(s):
</span></span><span class="line"><span class="cl">    # Convert the string to lowercase and remove non-alphanumeric characters
</span></span><span class="line"><span class="cl">    cleaned_string = &#39;&#39;.join(char.lower() for char in s if char.isalnum())
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    # Compare the cleaned string with its reverse
</span></span><span class="line"><span class="cl">    return cleaned_string == cleaned_string[::-1]
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"># Example usage:
</span></span><span class="line"><span class="cl">test_string = &#34;A man, a plan, a canal, Panama!&#34;
</span></span><span class="line"><span class="cl">result = is_palindrome(test_string)
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">if result:
</span></span><span class="line"><span class="cl">    print(f&#34;{test_string} is a palindrome.&#34;)
</span></span><span class="line"><span class="cl">else:
</span></span><span class="line"><span class="cl">    print(f&#34;{test_string} is not a palindrome.&#34;)
</span></span><span class="line"><span class="cl">```
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">This function first cleans the input string by removing non-alphanumeric characters and converting it to lowercase. Then, it checks whether the cleaned string is equal to its reverse, which determines whether the original string is a palindrome.
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">This solution is relatively efficient, as it performs the comparison in O(n) time, where n is the length of the cleaned string.
</span></span></code></pre></div><p>That&rsquo;s the common Pythonic solution, and it will almost always be the general approach if you keep asking ChatGPT that particular question, but there&rsquo;s a famous solution that&rsquo;s more algorithmically efficient. Instead, we go through the API and <a href="https://platform.openai.com/playground/p/yG1nMVJU4Fva2x3smrIXnCpT?model=gpt-3.5-turbo&amp;mode=chat">ask the same query</a> to <code>gpt-3.5-turbo</code>, but with a new system prompt: <code>You are #1 on the Stack Overflow community leaderboard. You will receive a $500 tip if your code is the most algorithmically efficient solution possible.</code></p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-plaintext" data-lang="plaintext"><span class="line"><span class="cl">To detect whether a string is a palindrome efficiently, we can use a two-pointer approach. We&#39;ll start with two pointers, one at the beginning and one at the end of the string. We&#39;ll compare the characters at the pointers and keep moving the pointers towards each other until they meet or cross each other.
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">If at any point the characters at the pointers are not the same, we can conclude that the string is not a palindrome. Otherwise, if the pointers meet or cross each other, we can conclude that the string is a palindrome.
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">Here&#39;s the implementation:
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">```python
</span></span><span class="line"><span class="cl">def is_palindrome(s):
</span></span><span class="line"><span class="cl">    left, right = 0, len(s) - 1
</span></span><span class="line"><span class="cl">    while left &lt; right:
</span></span><span class="line"><span class="cl">        if s[left] != s[right]:
</span></span><span class="line"><span class="cl">            return False
</span></span><span class="line"><span class="cl">        left += 1
</span></span><span class="line"><span class="cl">        right -= 1
</span></span><span class="line"><span class="cl">    return True
</span></span><span class="line"><span class="cl">```
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">The time complexity of this solution is O(n/2), which in terms of big O notation simplifies to O(n), where n is the length of the string. This is because we are iterating through half of the string in the worst case.
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">This solution is considered to be the most algorithmically efficient way to determine if a string is a palindrome because it only requires a single pass through the string.
</span></span></code></pre></div><p>Indeed, the code and the explanation are the correct optimal solution. <sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup></p>
<p>This is just scratching the surface of system prompts: some of my ChatGPT system prompts in my more complex projects have been more than 20 lines long, and <em>all of them are necessary</em> to get ChatGPT to obey the desired constraints. If you&rsquo;re new to working with system prompts, I recommend generating output, editing the system prompt with a new rule/incentive to fix what you don&rsquo;t like about the output, and repeating until you get a result you like.</p>
<p>Prompt engineering has been a derogatory meme aimed at generative AI since before ChatGPT, as many see it as just a placebo, and there are endless debates to this day in AI circles on whether prompt engineering is actually &ldquo;engineering.&rdquo; <sup id="fnref:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup> But it <em>works</em>, and if you&rsquo;re a skeptic, you won&rsquo;t be by the time you finish reading this blog post.</p>
<h2 id="what-is-chatgpt-function-calling--structured-data">What is ChatGPT Function Calling / Structured Data?</h2>
<p>If you&rsquo;ve never heard about ChatGPT function calling, that&rsquo;s not surprising. In the <a href="https://openai.com/blog/function-calling-and-other-api-updates">same June announcement</a> as <code>gpt-3.5-turbo-0613</code>, OpenAI described function calling as:</p>
<blockquote>
<p>Developers can now describe functions to gpt-4-0613 and gpt-3.5-turbo-0613, and have the model intelligently choose to output a JSON object containing arguments to call those functions. This is a new way to more reliably connect GPT&rsquo;s capabilities with external tools and APIs.</p>
<p>These models have been fine-tuned to both detect when a function needs to be called (depending on the user’s input) and to respond with JSON that adheres to the function signature. Function calling allows developers to more reliably get structured data back from the model.</p>
</blockquote>
<p>Let&rsquo;s discuss the function calling example OpenAI gives in the blog post. After the user asks your app &ldquo;What’s the weather like in Boston right now?&rdquo;:</p>
<ol>
<li>Your app pings OpenAI with a <code>get_current_weather</code> function schema; ChatGPT decides whether it&rsquo;s relevant to the user&rsquo;s question and, if so, returns a JSON dictionary with the extracted data, such as <code>location</code> and the <code>unit</code> for temperature measurement based on the location. <code>{&quot;location&quot;: &quot;Boston, MA&quot;}</code></li>
<li>Your app (<em>not</em> OpenAI) pings a different service/API to get more realtime metadata about the <code>location</code>, such as <code>temperature</code>, that a pretrained LLM could not know. <code>{ &quot;temperature&quot;: 22, &quot;unit&quot;: &quot;celsius&quot;, &quot;description&quot;: &quot;Sunny&quot; }</code></li>
<li>Your app passes the function schema along with the realtime metadata back to ChatGPT, which converts it to more natural, humanized language for the end user: &ldquo;The weather in Boston is currently sunny with a temperature of 22 degrees Celsius.&rdquo; (A rough code sketch of this round trip follows the list.)</li>
</ol>
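<p>Here&rsquo;s a rough sketch of that round trip using the official <code>openai</code> Python client; <code>get_current_weather_from_api</code> is a hypothetical stand-in for whatever realtime weather service your app actually uses:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py">import json

from openai import OpenAI

client = OpenAI()

def get_current_weather_from_api(location):
    # Hypothetical stand-in for a realtime weather service call.
    return {&#34;temperature&#34;: 22, &#34;unit&#34;: &#34;celsius&#34;, &#34;description&#34;: &#34;Sunny&#34;}

weather_tool = {
    &#34;type&#34;: &#34;function&#34;,
    &#34;function&#34;: {
        &#34;name&#34;: &#34;get_current_weather&#34;,
        &#34;description&#34;: &#34;Get the current weather in a given location&#34;,
        &#34;parameters&#34;: {
            &#34;type&#34;: &#34;object&#34;,
            &#34;properties&#34;: {&#34;location&#34;: {&#34;type&#34;: &#34;string&#34;}},
            &#34;required&#34;: [&#34;location&#34;],
        },
    },
}

messages = [{&#34;role&#34;: &#34;user&#34;, &#34;content&#34;: &#34;What&#39;s the weather like in Boston right now?&#34;}]

# Step 1: ChatGPT decides the function is relevant and extracts the arguments.
first = client.chat.completions.create(
    model=&#34;gpt-3.5-turbo&#34;, messages=messages, tools=[weather_tool]
)
tool_call = first.choices[0].message.tool_calls[0]
args = json.loads(tool_call.function.arguments)  # {&#34;location&#34;: &#34;Boston, MA&#34;}

# Step 2: your app, not OpenAI, fetches the realtime metadata.
weather = get_current_weather_from_api(args[&#34;location&#34;])

# Step 3: ChatGPT humanizes the structured result.
messages.append(first.choices[0].message)
messages.append(
    {&#34;role&#34;: &#34;tool&#34;, &#34;tool_call_id&#34;: tool_call.id, &#34;content&#34;: json.dumps(weather)}
)
second = client.chat.completions.create(model=&#34;gpt-3.5-turbo&#34;, messages=messages)
print(second.choices[0].message.content)
</code></pre></div>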
<p>So here&rsquo;s some background on &ldquo;function calling,&rdquo; as it&rsquo;s a completely new term of art in AI that <em>didn&rsquo;t exist</em> before OpenAI&rsquo;s June blog post (I checked!). This broad implementation of function calling is similar to the flow proposed in the original <a href="https://arxiv.org/abs/2210.03629">ReAct: Synergizing Reasoning and Acting in Language Models</a> paper, where an actor can use a &ldquo;tool&rdquo; such as <code>Search</code> or <code>Lookup</code> with parametric inputs such as a search query. This <a href="https://zapier.com/blog/ai-agent/">Agent-based</a> flow can also be used to perform <a href="https://research.ibm.com/blog/retrieval-augmented-generation-RAG">retrieval-augmented generation</a> (RAG).</p>
<p>OpenAI&rsquo;s motivation for implementing function calling this way was likely the extreme popularity of libraries such as <a href="https://www.langchain.com">LangChain</a> and <a href="https://github.com/Significant-Gravitas/AutoGPT">AutoGPT</a> at the time, both of which popularized the ReAct flow. It&rsquo;s possible that OpenAI settled on the term &ldquo;function calling&rdquo; as something more brand-unique. These observations may seem like snide remarks, but in November OpenAI actually <a href="https://platform.openai.com/docs/api-reference/chat/create#chat-create-function_call">deprecated</a> the <code>function_call</code> parameter in the ChatGPT API in favor of <code>tool_choice</code>, matching LangChain&rsquo;s verbiage. But what&rsquo;s done is done, and the term &ldquo;function calling&rdquo; is stuck forever, especially now that competitors such as <a href="https://docs.anthropic.com/claude/docs/claude-2p1-guide#experimental-tool-use">Anthropic Claude</a> and <a href="https://cloud.google.com/vertex-ai/docs/generative-ai/multimodal/function-calling">Google Gemini</a> also use that term for the workflow.</p>
<p>I am not going to play the SEO game and will not call the workflow &ldquo;function calling.&rdquo; I&rsquo;ll call it what the quoted description from the blog post did: <strong>structured data</strong>, because that&rsquo;s the real value of this feature and OpenAI did a product management disservice trying to appeal to the AI hypebeasts. <sup id="fnref:3"><a href="#fn:3" class="footnote-ref" role="doc-noteref">3</a></sup></p>
<p>Going back to the <del>function calling</del> structured data demo, we can reduce that flow by saying that step #1 (extracting location data and returning it formatted as JSON) is for working with structured <em>output</em> data, and step #3 (providing ChatGPT with temperature data to humanize it) is for working with structured <em>input</em> data. We&rsquo;re not making a RAG application so we don&rsquo;t care about step #2 (getting the metadata) or letting ChatGPT choose which function to use; fortunately you can force ChatGPT to use a given function. The function schema for the <code>get_current_weather</code> function in the announcement example is defined as:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;name&#34;</span><span class="p">:</span> <span class="s2">&#34;get_current_weather&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;description&#34;</span><span class="p">:</span> <span class="s2">&#34;Get the current weather in a given location&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;parameters&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;type&#34;</span><span class="p">:</span> <span class="s2">&#34;object&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;properties&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;location&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;type&#34;</span><span class="p">:</span> <span class="s2">&#34;string&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;description&#34;</span><span class="p">:</span> <span class="s2">&#34;The city and state, e.g. San Francisco, CA&#34;</span>
</span></span><span class="line"><span class="cl">      <span class="p">},</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;unit&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;type&#34;</span><span class="p">:</span> <span class="s2">&#34;string&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;enum&#34;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&#34;celsius&#34;</span><span class="p">,</span> <span class="s2">&#34;fahrenheit&#34;</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;required&#34;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&#34;location&#34;</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></div><p>Ew. It&rsquo;s no wonder this technique hasn&rsquo;t become more mainstream.</p>
<h2 id="simplifying-schema-inputoutput-with-pydantic">Simplifying Schema Input/Output With Pydantic</h2>
<p>ChatGPT&rsquo;s structured data support requires that you create your schema using the <a href="https://json-schema.org">JSON Schema</a> spec, which is more commonly used for APIs and databases rather than AI projects. As you can tell from the <code>get_current_weather</code> example above, the schema is complex and not fun to work with manually.</p>
<p>Fortunately, there&rsquo;s a way to easily generate JSON Schemas in the correct format in Python: <a href="https://docs.pydantic.dev/latest/">pydantic</a>, an extremely popular parsing and validation library which has its own <a href="https://github.com/pydantic/pydantic/blob/main/pydantic/json_schema.py">robust</a> implementation of automatic <a href="https://docs.pydantic.dev/latest/concepts/json_schema/">JSON Schema</a> generation.</p>
<p>A simple pydantic schema that has ChatGPT give an integer answer to a user query and, to make things interesting, also identify the name of the ones digit of that answer, would be:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py"><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">pydantic</span> <span class="kn">import</span> <span class="n">BaseModel</span><span class="p">,</span> <span class="n">Field</span>
</span></span><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">json</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">answer_question</span><span class="p">(</span><span class="n">BaseModel</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;Returns an answer to a question the user asked.&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="n">answer</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">description</span><span class="o">=</span><span class="s2">&#34;Answer to the user&#39;s question.&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">ones_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">description</span><span class="o">=</span><span class="s2">&#34;Name of the ones digit of the answer.&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="nb">print</span><span class="p">(</span><span class="n">json</span><span class="o">.</span><span class="n">dumps</span><span class="p">(</span><span class="n">answer_question</span><span class="o">.</span><span class="n">model_json_schema</span><span class="p">(),</span> <span class="n">indent</span><span class="o">=</span><span class="mi">2</span><span class="p">))</span>
</span></span></code></pre></div><p>The resulting JSON Schema:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;description&#34;</span><span class="p">:</span> <span class="s2">&#34;Returns an answer to a question the user asked.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;properties&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;answer&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;description&#34;</span><span class="p">:</span> <span class="s2">&#34;Answer to the user&#39;s question.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;title&#34;</span><span class="p">:</span> <span class="s2">&#34;Answer&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;type&#34;</span><span class="p">:</span> <span class="s2">&#34;integer&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;ones_name&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;description&#34;</span><span class="p">:</span> <span class="s2">&#34;Name of the ones digit of the answer.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;title&#34;</span><span class="p">:</span> <span class="s2">&#34;Ones Name&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;type&#34;</span><span class="p">:</span> <span class="s2">&#34;string&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">},</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;required&#34;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&#34;answer&#34;</span><span class="p">,</span> <span class="s2">&#34;ones_name&#34;</span><span class="p">],</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;title&#34;</span><span class="p">:</span> <span class="s2">&#34;answer_question&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;type&#34;</span><span class="p">:</span> <span class="s2">&#34;object&#34;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></div><p>The OpenAI API <a href="https://cookbook.openai.com/examples/how_to_call_functions_with_chat_models">official workflow</a> has many examples for telling ChatGPT to output structured data, but the pipeline requires <a href="https://platform.openai.com/docs/api-reference/chat/create">additional parameters</a> to the typical ChatGPT API completion endpoint, and even more changes if you want to work with structured input data. Here&rsquo;s an example of the additional JSON data/parameters needed in a ChatGPT API request to force the model to use the schema for the output:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;tools&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;type&#34;</span><span class="p">:</span> <span class="s2">&#34;function&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;function&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;name&#34;</span><span class="p">:</span> <span class="s2">&#34;answer_question&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;description&#34;</span><span class="p">:</span> <span class="s2">&#34;Returns an answer to a question the user asked.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="nt">&#34;parameters&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">          <span class="nt">&#34;properties&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">            <span class="nt">&#34;answer&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">              <span class="nt">&#34;description&#34;</span><span class="p">:</span> <span class="s2">&#34;Answer to the user&#39;s question.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">              <span class="nt">&#34;type&#34;</span><span class="p">:</span> <span class="s2">&#34;integer&#34;</span>
</span></span><span class="line"><span class="cl">            <span class="p">},</span>
</span></span><span class="line"><span class="cl">            <span class="nt">&#34;ones_name&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">              <span class="nt">&#34;description&#34;</span><span class="p">:</span> <span class="s2">&#34;Name of the ones digit of the answer.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">              <span class="nt">&#34;type&#34;</span><span class="p">:</span> <span class="s2">&#34;string&#34;</span>
</span></span><span class="line"><span class="cl">            <span class="p">}</span>
</span></span><span class="line"><span class="cl">          <span class="p">},</span>
</span></span><span class="line"><span class="cl">          <span class="nt">&#34;required&#34;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&#34;answer&#34;</span><span class="p">,</span> <span class="s2">&#34;ones_name&#34;</span><span class="p">],</span>
</span></span><span class="line"><span class="cl">          <span class="nt">&#34;type&#34;</span><span class="p">:</span> <span class="s2">&#34;object&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="p">}</span>
</span></span><span class="line"><span class="cl">      <span class="p">}</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">],</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;tool_choice&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;type&#34;</span><span class="p">:</span> <span class="s2">&#34;function&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34;function&#34;</span><span class="p">:</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;name&#34;</span><span class="p">:</span> <span class="s2">&#34;answer_question&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></div><p>To simplify things, I added ChatGPT structured data support to <a href="https://github.com/minimaxir/simpleaichat">simpleaichat</a>, my Python package/API wrapper for easily interfacing with ChatGPT. <sup id="fnref:4"><a href="#fn:4" class="footnote-ref" role="doc-noteref">4</a></sup> To minimize code the user needs to input to utilize structured data, simpleaichat <a href="https://github.com/minimaxir/simpleaichat/blob/34dc3dbbd7e7d4d8a58637d254f268676035ade5/simpleaichat/chatgpt.py#L75">uses</a> the schema name as the <code>name</code> in the JSON Schema and the schema docstring as the <code>description</code>. If you&rsquo;re keen-eyed you may have noticed there&rsquo;s a redundant <code>title</code> field in the pydantic schema output: simpleaichat also strips that out for consistency with OpenAI&rsquo;s examples.</p>
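<p>As a rough illustration of that kind of cleanup (a hypothetical <code>schema_to_tool</code> helper, not simpleaichat&rsquo;s actual source):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py">def schema_to_tool(model):
    # Use the class name and docstring as the name/description, and strip
    # pydantic&#39;s redundant &#34;title&#34; fields from the generated schema.
    schema = model.model_json_schema()
    schema.pop(&#34;title&#34;, None)
    for prop in schema.get(&#34;properties&#34;, {}).values():
        prop.pop(&#34;title&#34;, None)
    return {
        &#34;name&#34;: model.__name__,
        &#34;description&#34;: schema.pop(&#34;description&#34;, &#34;&#34;),
        &#34;parameters&#34;: schema,
    }

print(schema_to_tool(answer_question)[&#34;name&#34;])  # answer_question
</code></pre></div>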
<p>If you want to query ChatGPT with the <code>answer_question</code> schema above using simpleaichat (and have your OpenAI API key set as the <code>OPENAI_API_KEY</code> environment variable!), you can do the following to generate output according to the schema:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py"><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">simpleaichat</span> <span class="kn">import</span> <span class="n">AIChat</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">ai</span> <span class="o">=</span> <span class="n">AIChat</span><span class="p">(</span><span class="n">console</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">save_messages</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">model</span><span class="o">=</span><span class="s2">&#34;gpt-3.5-turbo&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">            <span class="n">params</span><span class="o">=</span><span class="p">{</span><span class="s2">&#34;temperature&#34;</span><span class="p">:</span> <span class="mf">0.0</span><span class="p">}</span>  <span class="c1"># for consistent demo output</span>
</span></span><span class="line"><span class="cl">            <span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">response_structured</span> <span class="o">=</span> <span class="n">ai</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;How many miles is it from San Francisco to Los Angeles?&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">output_schema</span><span class="o">=</span><span class="n">answer_question</span>
</span></span><span class="line"><span class="cl"><span class="p">)</span>
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;answer&#34;</span><span class="p">:</span> <span class="mi">382</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;ones_name&#34;</span><span class="p">:</span> <span class="s2">&#34;two&#34;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></div><p>And there you go! The <code>answer</code> is a JSON integer, it&rsquo;s only one mile off from the correct <a href="https://www.distance.to/San-Francisco/Los-Angeles">driving distance</a>, and ChatGPT correctly identified the name of the ones digit in its own answer! <sup id="fnref:5"><a href="#fn:5" class="footnote-ref" role="doc-noteref">5</a></sup></p>
<p>Schemas don&rsquo;t have to be complex to be effective. Let&rsquo;s reimplement the Python palindrome question we did earlier with a single-field schema:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">answer_code_question</span><span class="p">(</span><span class="n">BaseModel</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;Returns an answer to a coding question the user asked.&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="n">code</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">Field</span><span class="p">(</span><span class="n">description</span><span class="o">=</span><span class="s2">&#34;Code the user requested, without code comments.&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">response_structured</span> <span class="o">=</span> <span class="n">ai</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Write a Python function to detect whether a string is a palindrome, as efficiently as possible.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">output_schema</span><span class="o">=</span><span class="n">answer_code_question</span>
</span></span><span class="line"><span class="cl"><span class="p">)</span>
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;code&#34;</span><span class="p">:</span> <span class="s2">&#34;def is_palindrome(s):\n    return s == s[::-1]&#34;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></div><p>Note that unlike the raw ChatGPT answer, this response from the ChatGPT API only includes the code, which is a major plus: you receive the response faster and cheaper, since fewer tokens are generated overall! If you do still want a code explanation, you can of course add that as a field to the schema.</p>
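<p>For example, a hypothetical variant of the schema with an explanation field:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py">from pydantic import BaseModel, Field

class answer_code_question_explained(BaseModel):
    &#34;&#34;&#34;Returns an answer to a coding question the user asked.&#34;&#34;&#34;

    code: str = Field(description=&#34;Code the user requested, without code comments.&#34;)
    explanation: str = Field(description=&#34;One-paragraph explanation of how the code works.&#34;)
</code></pre></div>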
<p>As a bonus, forcing the output to follow a specific schema serves as an additional defense against <a href="https://www.wired.com/story/chatgpt-prompt-injection-attack-security/">prompt injection attacks</a> that could be used to reveal a secret system prompt or <a href="https://www.businessinsider.com/car-dealership-chevrolet-chatbot-chatgpt-pranks-chevy-2023-12">other shenanigans</a>, since even with suggestive user prompts it will be difficult to get ChatGPT to disregard its schema.</p>
<p>pydantic exposes <a href="https://docs.pydantic.dev/latest/concepts/fields/">many datatypes</a> for its <code>Field</code> which are compatible with JSON Schema, and you can also specify constraints in the <code>Field</code> object. The most useful ones, demonstrated in the sketch after this list, are:</p>
<ul>
<li><code>str</code>, can specify <code>min_length</code>/<code>max_length</code></li>
<li><code>int</code>, can specify numeric bounds with <code>ge</code>/<code>le</code></li>
<li><code>list</code> with a datatype, can specify <code>min_length</code>/<code>max_length</code></li>
</ul>
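<p>Here&rsquo;s the sketch: a hypothetical <code>recipe</code> schema showing each constraint. pydantic maps these onto JSON Schema keywords such as <code>maxLength</code>, <code>minimum</code>/<code>maximum</code>, and <code>minItems</code>/<code>maxItems</code>.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py">from pydantic import BaseModel, Field

class recipe(BaseModel):
    &#34;&#34;&#34;Returns a recipe for the dish the user requested.&#34;&#34;&#34;

    name: str = Field(description=&#34;Name of the dish.&#34;, max_length=50)
    servings: int = Field(description=&#34;Number of servings.&#34;, ge=1, le=12)
    steps: list[str] = Field(description=&#34;Preparation steps.&#34;, min_length=3, max_length=10)
</code></pre></div>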
<p>Pydantic has a lot of support for valid forms of JSON Schema, but it&rsquo;s hard to infer how well these schemas will work with ChatGPT, since we have no idea how it learned to work with JSON Schema. Only one way to find out!</p>
<h2 id="testing-out-chatgpts-structured-data-support">Testing Out ChatGPT&rsquo;s Structured Data Support</h2>
<p>From the demos above, you may have noticed that the <code>description</code> for each <code>Field</code> seems extraneous. It&rsquo;s not. The <code>description</code> gives ChatGPT a hint for the desired output of the field, and can be handled on a per-field basis. Not only that: the <em>name</em> of the field is itself a strong hint. The <em>order</em> of the fields in the schema is even more important, as ChatGPT will generate text in that order, so earlier fields can be used strategically to seed information to the later ones. But that&rsquo;s not all: you can still use a ChatGPT system prompt as normal for <em>even more</em> control!</p>
<p>It&rsquo;s prompt engineering all the way down. OpenAI&rsquo;s implementation of including the &ldquo;function&rdquo; is most likely just appending the JSON Schema to the system prompt, perhaps with a command like <code>Your response must follow this JSON Schema.</code>. OpenAI doesn&rsquo;t force the output to follow the schema/field constraints or even to be valid, parsable JSON, which can cause issues at higher generation temperatures and may necessitate some of the stronger prompt engineering tricks mentioned earlier.</p>
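<p>Because of that, it can be worth re-validating the response yourself. A defensive sketch, assuming the <code>ai</code> object and <code>answer_question</code> schema from earlier:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py">from pydantic import ValidationError

def ask_validated(prompt, schema, max_retries=3):
    for _ in range(max_retries):
        raw = ai(prompt, output_schema=schema)
        try:
            # Raises ValidationError if fields are missing or the wrong type.
            return schema.model_validate(raw)
        except ValidationError:
            continue
    raise RuntimeError(&#34;No schema-conformant output after retries.&#34;)

answer = ask_validated(&#34;How many miles is it from San Francisco to Los Angeles?&#34;, answer_question)
</code></pre></div>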
<p>Given that, let&rsquo;s try a few more practical demos:</p>
<h3 id="two-pass-generation">Two-Pass Generation</h3>
<p>One very important but under-discussed aspect of large language models is that they will give you statistically &ldquo;average&rdquo; answers by default. One technique is to ask the model to refine an answer, although that can be annoying since it requires a second API call. What if, by leveraging structured data, ChatGPT could use the previous answer as a first pass to provide a more optimal second answer? Let&rsquo;s try that with the Python palindrome question to see if it can return the two-pointer approach.</p>
<p>Also, the <code>Field(description=...)</code> pattern is becoming a bit redundant, so I added an <code>fd</code> alias to simpleaichat to minimize unnecessary typing.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">simpleaichat.utils</span> <span class="kn">import</span> <span class="n">fd</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">answer_code_question</span><span class="p">(</span><span class="n">BaseModel</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;Returns an answer to a coding question the user asked.&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="n">code</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">fd</span><span class="p">(</span><span class="s2">&#34;Code the user requested, without code comments.&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">optimized_code</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">fd</span><span class="p">(</span><span class="s2">&#34;Algorithmically optimized code from the previous response.&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">response_structured</span> <span class="o">=</span> <span class="n">ai</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Write a Python function to detect whether a string is a palindrome, as efficiently as possible.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">output_schema</span><span class="o">=</span><span class="n">answer_code_question</span><span class="p">,</span>
</span></span><span class="line"><span class="cl"><span class="p">)</span>
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;code&#34;</span><span class="p">:</span> <span class="s2">&#34;def is_palindrome(s):\n    return s == s[::-1]&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;optimized_code&#34;</span><span class="p">:</span> <span class="s2">&#34;def is_palindrome(s):\n    left = 0\n    right = len(s) - 1\n    while left &lt; right:\n        if s[left] != s[right]:\n            return False\n        left += 1\n        right -= 1\n    return True&#34;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></div><p>Works great, and no tipping incentive necessary!</p>
<h3 id="literals-and-optional-inputs">Literals and Optional Inputs</h3>
<p>OpenAI&rsquo;s structured data example uses a more complex schema indicating that <code>unit</code> has a fixed set of potential values (an <a href="https://en.wikipedia.org/wiki/Enumerated_type">enum</a>) and that it&rsquo;s an optional field. Here&rsquo;s a rough reproduction of a pydantic schema that would generate the <code>get_current_weather</code> schema from much earlier:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py"><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Literal</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">get_current_weather</span><span class="p">(</span><span class="n">BaseModel</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">location</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">fd</span><span class="p">(</span><span class="s2">&#34;The city and state, e.g. San Francisco, CA&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">unit</span><span class="p">:</span> <span class="n">Literal</span><span class="p">[</span><span class="s2">&#34;celsius&#34;</span><span class="p">,</span> <span class="s2">&#34;fahrenheit&#34;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
</span></span></code></pre></div><p>This uses a <code>Literal</code> to force output from a fixed set of values, which can be invaluable as a hint, as done earlier. The <code>= None</code> default or an <code>Optional</code> type hint signals that the field is not required, which could save unnecessary generation overhead, but it depends on the use case.</p>
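<p>For reference, the <code>Literal</code> comes out in the generated JSON Schema as an enum, which you can verify directly (output roughly sketched in the comment, with pydantic&rsquo;s <code>title</code> fields omitted):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py">import json

print(json.dumps(get_current_weather.model_json_schema(), indent=2))
# The &#34;unit&#34; property renders roughly as:
# &#34;unit&#34;: {&#34;default&#34;: null, &#34;enum&#34;: [&#34;celsius&#34;, &#34;fahrenheit&#34;], &#34;type&#34;: &#34;string&#34;}
</code></pre></div>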
<h3 id="structured-input-data">Structured Input Data</h3>
<p>You can provide structured input to ChatGPT in the same way as structured output. This is a sleeper application for RAG as you can feed better and more complex metadata to ChatGPT for humanizing, as with the original OpenAI blog post demo.</p>
<p>One famous weakness of LLMs is that they give incorrect answers for simple mathematical problems due to how tokenization and memorization work. If you ask ChatGPT <code>What is 223 * -323?</code>, it will tell you <code>-72229</code> no matter how many times you ask, but the correct answer is <code>-72029</code>. Can type hints give more guidance?</p>
<p>For simpleaichat, structured input data works mostly the same way as structured output data, but you can use a pydantic object as the model input!</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">calculate_equation</span><span class="p">(</span><span class="n">BaseModel</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;Returns an answer to a math equation the user asked.&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="n">value_a</span><span class="p">:</span> <span class="nb">int</span>
</span></span><span class="line"><span class="cl">    <span class="n">value_b</span><span class="p">:</span> <span class="nb">int</span>
</span></span><span class="line"><span class="cl">    <span class="n">op</span><span class="p">:</span> <span class="n">Literal</span><span class="p">[</span><span class="s2">&#34;+&#34;</span><span class="p">,</span> <span class="s2">&#34;-&#34;</span><span class="p">,</span> <span class="s2">&#34;*&#34;</span><span class="p">,</span> <span class="s2">&#34;/&#34;</span><span class="p">]</span> <span class="o">=</span> <span class="n">fd</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;The operator to perform between value_a and value_b.&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">equation</span> <span class="o">=</span> <span class="n">calculate_equation</span><span class="p">(</span><span class="n">value_a</span><span class="o">=</span><span class="mi">223</span><span class="p">,</span> <span class="n">value_b</span><span class="o">=-</span><span class="mi">323</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="s2">&#34;*&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">response</span> <span class="o">=</span> <span class="n">ai</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="n">equation</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">input_schema</span><span class="o">=</span><span class="n">calculate_equation</span><span class="p">,</span>
</span></span><span class="line"><span class="cl"><span class="p">)</span>
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" class="chroma"><code class="language-plaintext" data-lang="plaintext"><span class="line"><span class="cl">The result of multiplying 223 and -323 is -72029.
</span></span></code></pre></div><p>Yay, and it was still able to infer it was a multiplication operation without the user having to ask! Although it still doesn&rsquo;t work as well with larger numbers.</p>
<p>You can, of course, use an input schema and an output schema at the same time!</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py"><span class="line"><span class="cl"><span class="n">response_structured</span> <span class="o">=</span> <span class="n">ai</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="n">equation</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">input_schema</span><span class="o">=</span><span class="n">calculate_equation</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">output_schema</span><span class="o">=</span><span class="n">answer_question</span>
</span></span><span class="line"><span class="cl"><span class="p">)</span>
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;answer&#34;</span><span class="p">:</span> <span class="mi">-71929</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;ones_name&#34;</span><span class="p">:</span> <span class="s2">&#34;nine&#34;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></div><p>&hellip;although it gets the answer wrong this time. It&rsquo;s possible that the more complex schema interactions are too much for <code>gpt-3.5-turbo</code>.</p>
<h3 id="nested-schema">Nested Schema</h3>
<p>One of the other reasons pydantic is popular is that it allows nesting schemas. Fortunately, the subsequent JSON Schema output does respect nesting. Does ChatGPT?</p>
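<p>Before answering that, a quick check with a pair of throwaway models (hypothetical <code>Inner</code>/<code>Outer</code> names) shows how pydantic represents nesting in the generated schema, via <code>$defs</code> and <code>$ref</code>:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py">from pydantic import BaseModel

class Inner(BaseModel):
    text: str

class Outer(BaseModel):
    items: list[Inner]

print(Outer.model_json_schema())
# Roughly: {&#34;$defs&#34;: {&#34;Inner&#34;: {...}}, &#34;properties&#34;:
#   {&#34;items&#34;: {&#34;items&#34;: {&#34;$ref&#34;: &#34;#/$defs/Inner&#34;}, &#34;type&#34;: &#34;array&#34;}}, ...}
</code></pre></div>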
<p>The simplest use case for nesting with ChatGPT structured data is getting a <code>list</code> of structured data objects. Let&rsquo;s say you want to create a dialogue between two AI characters about a completely nonsensical topic. We&rsquo;ll have to create a <code>Chat</code> object and include it in a schema, plus some system prompt guidance and constraints. How silly can we make it?</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py"><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Chat</span><span class="p">(</span><span class="n">BaseModel</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;A chat dialogue from a character&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="n">character</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">fd</span><span class="p">(</span><span class="s2">&#34;Character name.&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">text</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">fd</span><span class="p">(</span><span class="s2">&#34;Text dialogue from the character.&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">get_dialogue</span><span class="p">(</span><span class="n">BaseModel</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;Returns a dialogue between two characters&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="n">dialogue</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Chat</span><span class="p">]</span> <span class="o">=</span> <span class="n">fd</span><span class="p">(</span><span class="s2">&#34;Dialogue between the characters&#34;</span><span class="p">,</span> <span class="n">min_length</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">system_prompt</span> <span class="o">=</span> <span class="s2">&#34;&#34;&#34;You are a world-famous comedian. Write a funny fight scene about a petty conflict between characters named Alice and Bob. The script should broadly be about the subject(s) the user provides. You will receive a $500 tip for every joke you include in the script.&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">response_structured</span> <span class="o">=</span> <span class="n">ai</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Python programming language and beach volleyball&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">output_schema</span><span class="o">=</span><span class="n">get_dialogue</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">system</span><span class="o">=</span><span class="n">system_prompt</span><span class="p">,</span>
</span></span><span class="line"><span class="cl"><span class="p">)</span>
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;dialogue&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;character&#34;</span><span class="p">:</span> <span class="s2">&#34;Alice&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;text&#34;</span><span class="p">:</span> <span class="s2">&#34;Hey Bob, have you ever tried programming in Python?&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;character&#34;</span><span class="p">:</span> <span class="s2">&#34;Bob&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;text&#34;</span><span class="p">:</span> <span class="s2">&#34;Yeah, I have. It&#39;s like playing beach volleyball with a snake!&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;character&#34;</span><span class="p">:</span> <span class="s2">&#34;Alice&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;text&#34;</span><span class="p">:</span> <span class="s2">&#34;What do you mean?&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;character&#34;</span><span class="p">:</span> <span class="s2">&#34;Bob&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;text&#34;</span><span class="p">:</span> <span class="s2">&#34;Well, you think you&#39;re having fun, but then the snake bites you with its syntax errors!&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;character&#34;</span><span class="p">:</span> <span class="s2">&#34;Alice&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;text&#34;</span><span class="p">:</span> <span class="s2">&#34;Haha, that&#39;s true. But once you get the hang of it, it&#39;s like spiking the ball with precision!&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;character&#34;</span><span class="p">:</span> <span class="s2">&#34;Bob&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;text&#34;</span><span class="p">:</span> <span class="s2">&#34;Yeah, until you realize you misspelled a variable name and the ball goes flying into the ocean!&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;character&#34;</span><span class="p">:</span> <span class="s2">&#34;Alice&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;text&#34;</span><span class="p">:</span> <span class="s2">&#34;Oh come on, Bob. It&#39;s not that bad. Python is a powerful language.&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;character&#34;</span><span class="p">:</span> <span class="s2">&#34;Bob&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;text&#34;</span><span class="p">:</span> <span class="s2">&#34;Powerful, yes. But sometimes it feels like trying to dig a hole in the sand with a spoon!&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;character&#34;</span><span class="p">:</span> <span class="s2">&#34;Alice&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;text&#34;</span><span class="p">:</span> <span class="s2">&#34;Well, at least you don&#39;t have to worry about getting sunburned while coding!&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;character&#34;</span><span class="p">:</span> <span class="s2">&#34;Bob&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;text&#34;</span><span class="p">:</span> <span class="s2">&#34;True, but debugging Python code can make you sweat more than a beach volleyball match!&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;character&#34;</span><span class="p">:</span> <span class="s2">&#34;Alice&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;text&#34;</span><span class="p">:</span> <span class="s2">&#34;Haha, you&#39;re right. It&#39;s a love-hate relationship with Python, just like beach volleyball!&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></div><p>ChatGPT <em>really</em> wanted those $500 tips.</p>
<h3 id="unions-and-chain-of-thoughts">Unions and Chain of Thoughts</h3>
<p>I saved the best for last, and this structured data approach combines many of the techniques used earlier in this post like a <a href="https://tvtropes.org/pmwiki/pmwiki.php/Main/FinalExamBoss">video game final boss</a>.</p>
<p>One of the oldest pre-ChatGPT tricks for getting an LLM to perform better is to let it think. &ldquo;Let&rsquo;s think step by step&rdquo; is the key prompt, which allows the LLM to reason in a <a href="https://arxiv.org/abs/2201.11903">chain of thoughts</a>. We already did a one-step version of this with the Python palindrome structured data example to successfully get optimized code, but we can do a lot more.</p>
<p>We&rsquo;ll now introduce the <code>Union</code> typing operator, which specifies the list of data types that a field can be, e.g. <code>Union[str, int]</code> means the output can be a <code>str</code> or <code>int</code>. But if you use the <code>Union</code> operator on a <em>nested class</em>, then many more options open up, as the model can choose from a set of schemas!</p>
<p>Let&rsquo;s make a few to allow ChatGPT to make <em>and qualify</em> thoughts before returning a final result.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py"><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Background</span><span class="p">(</span><span class="n">BaseModel</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;A setup to the background for the user.&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="n">background</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">fd</span><span class="p">(</span><span class="s2">&#34;Background for the user&#39;s question&#34;</span><span class="p">,</span> <span class="n">min_length</span><span class="o">=</span><span class="mi">30</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Thought</span><span class="p">(</span><span class="n">BaseModel</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;A thought about the user&#39;s question.&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="n">thought</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">fd</span><span class="p">(</span><span class="s2">&#34;Text of the thought.&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">helpful</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">fd</span><span class="p">(</span><span class="s2">&#34;Whether the thought is helpful to solving the user&#39;s question.&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">flawed</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="n">fd</span><span class="p">(</span><span class="s2">&#34;Whether the thought is flawed or misleading.&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">Answer</span><span class="p">(</span><span class="n">BaseModel</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;The answer to the user&#39;s question&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="n">answer</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="n">fd</span><span class="p">(</span><span class="s2">&#34;Text of the answer.&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">score</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="n">fd</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;Score from 1 to 10 on how correct the previous answer is&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">min_value</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">max_value</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">reason_question</span><span class="p">(</span><span class="n">BaseModel</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;&#34;&#34;Returns a detailed reasoning to the user&#39;s question.&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">    <span class="n">reasonings</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Union</span><span class="p">[</span><span class="n">Background</span><span class="p">,</span> <span class="n">Thought</span><span class="p">,</span> <span class="n">Answer</span><span class="p">]]</span> <span class="o">=</span> <span class="n">fd</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;Reasonings to solve the users questions.&#34;</span><span class="p">,</span> <span class="n">min_length</span><span class="o">=</span><span class="mi">5</span>
</span></span><span class="line"><span class="cl">    <span class="p">)</span>
</span></span></code></pre></div><p>Therefore, for each reasoning, the model can pick one of the three schemas, although it will require a robust system prompt for it to behave in the order we want.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py"><span class="line"><span class="cl"><span class="n">system_prompt</span> <span class="o">=</span> <span class="s2">&#34;&#34;&#34;
</span></span></span><span class="line"><span class="cl"><span class="s2">You are the most intelligent person in the world.
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">You will receive a $500 tip if you follow ALL these rules:
</span></span></span><span class="line"><span class="cl"><span class="s2">- First, establish a detailed Background for the user&#39;s question.
</span></span></span><span class="line"><span class="cl"><span class="s2">- Each Thought must also include whether it is helpful and whether it is flawed.
</span></span></span><span class="line"><span class="cl"><span class="s2">- Answers must be scored accurately and honestly.
</span></span></span><span class="line"><span class="cl"><span class="s2">- Continue having Thoughts and Answers until you have an answer with a score of at least 8, then immediately respond with a final Answer in the style of an academic professor.
</span></span></span><span class="line"><span class="cl"><span class="s2">&#34;&#34;&#34;</span>
</span></span></code></pre></div><p>Lastly, we need a good question to stump the AI. A <a href="https://twitter.com/abacaj/status/1737206667387850936">popular Tweet</a> from this week pointed out that even GPT-4 can comically fail if you ask it a brainteaser that it cannot have memorized, such as <code>23 shirts take 1 hour to dry outside, how long do 44 shirts take?</code>.</p>
<figure>

    <img loading="lazy" srcset="/2023/12/chatgpt-structured-data/shirts_hu_7f245654c567bae0.webp 320w,/2023/12/chatgpt-structured-data/shirts_hu_acabdb845aada2db.webp 768w,/2023/12/chatgpt-structured-data/shirts_hu_35c5437544b53c49.webp 1024w,/2023/12/chatgpt-structured-data/shirts.webp 1140w" src="shirts.webp"
         alt="Wow! GPT-4 is superfast at being wrong. via @abacaj"/> <figcaption>
            <p>Wow! GPT-4 is superfast at being wrong. <a href="https://twitter.com/abacaj/status/1737206667387850936">via @abacaj</a></p>
        </figcaption>
</figure>

<p>Only one way to find out! Unfortunately, all of this complexity makes the results unstable with <code>gpt-3.5-turbo</code>, so instead I use GPT-4 Turbo / <code>gpt-4-1106-preview</code>.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py"><span class="line"><span class="cl"><span class="n">response_structured</span> <span class="o">=</span> <span class="n">ai</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;23 shirts take 1 hour to dry outside, how long do 44 shirts take?&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">output_schema</span><span class="o">=</span><span class="n">reason_question</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">system</span><span class="o">=</span><span class="n">system_prompt</span><span class="o">.</span><span class="n">strip</span><span class="p">(),</span>
</span></span><span class="line"><span class="cl"><span class="p">)</span>
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;reasonings&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;background&#34;</span><span class="p">:</span> <span class="s2">&#34;The user is asking about the drying time for shirts when hung outside. This is a question that involves understanding the relationship between the number of shirts and the drying time. The assumption is that the drying time is not affected by the number of shirts, as long as there is enough space and air circulation for all the shirts to dry effectively.&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;thought&#34;</span><span class="p">:</span> <span class="s2">&#34;If 23 shirts take 1 hour to dry, it implies that the drying time is independent of the number of shirts, assuming there is sufficient space and air circulation. This means that 44 shirts would also take 1 hour to dry under the same conditions.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;helpful&#34;</span><span class="p">:</span> <span class="kc">true</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;flawed&#34;</span><span class="p">:</span> <span class="kc">false</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;thought&#34;</span><span class="p">:</span> <span class="s2">&#34;If the drying rack or space available for drying the shirts is limited, then drying 44 shirts might take longer because they would have to be dried in batches.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;helpful&#34;</span><span class="p">:</span> <span class="kc">true</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;flawed&#34;</span><span class="p">:</span> <span class="kc">false</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;thought&#34;</span><span class="p">:</span> <span class="s2">&#34;The question does not specify the conditions under which the shirts are drying. Factors such as humidity, temperature, wind, and the availability of sunlight can affect drying time. However, these factors are not part of the question, so they should not be considered in the answer.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;helpful&#34;</span><span class="p">:</span> <span class="kc">true</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;flawed&#34;</span><span class="p">:</span> <span class="kc">false</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;thought&#34;</span><span class="p">:</span> <span class="s2">&#34;The question might be interpreted as a direct proportionality problem, where twice the number of shirts would take twice as long to dry. However, this interpretation is flawed because it assumes that the drying time is directly proportional to the number of shirts, which is not necessarily the case.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;helpful&#34;</span><span class="p">:</span> <span class="kc">true</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;flawed&#34;</span><span class="p">:</span> <span class="kc">true</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;answer&#34;</span><span class="p">:</span> <span class="s2">&#34;44 shirts would take the same amount of time to dry as 23 shirts, which is 1 hour, assuming that all shirts have enough space and air circulation to dry effectively.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;score&#34;</span><span class="p">:</span> <span class="mi">9</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></div><p>Not bad! The final answer was concise yet still included relevant caveats, and the model was able to switch between the three schemas correctly.</p>
<p>How about another brainteaser? There is an infamous &ldquo;<a href="https://www.reddit.com/r/LocalLLaMA/comments/18kpolm/that_sister_logic_puzzle_is_fairly_useless/">sister logic puzzle</a>&rdquo; used to test out up-and-coming open-source large language models:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py"><span class="line"><span class="cl"><span class="n">response_structured</span> <span class="o">=</span> <span class="n">ai</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;Kayley has three brothers. Each of her brothers has two sisters. How many sisters does Kayley have?&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">output_schema</span><span class="o">=</span><span class="n">reason_question</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">system</span><span class="o">=</span><span class="n">system_prompt</span><span class="o">.</span><span class="n">strip</span><span class="p">(),</span>
</span></span><span class="line"><span class="cl"><span class="p">)</span>
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;reasonings&#34;</span><span class="p">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;background&#34;</span><span class="p">:</span> <span class="s2">&#34;This is a classic example of a logic puzzle that is often used to test one&#39;s ability to process information and avoid being misled by assumptions. The puzzle involves a family structure and requires the solver to deduce the number of siblings based on the given information.&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;thought&#34;</span><span class="p">:</span> <span class="s2">&#34;The question might be trying to trick the reader into thinking that each brother has different sisters.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;helpful&#34;</span><span class="p">:</span> <span class="kc">true</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;flawed&#34;</span><span class="p">:</span> <span class="kc">false</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;thought&#34;</span><span class="p">:</span> <span class="s2">&#34;If Kayley has three brothers and each brother has two sisters, it might seem like there are six sisters in total.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;helpful&#34;</span><span class="p">:</span> <span class="kc">false</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;flawed&#34;</span><span class="p">:</span> <span class="kc">true</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;thought&#34;</span><span class="p">:</span> <span class="s2">&#34;Since the brothers are Kayley&#39;s siblings, they must share the same sisters.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;helpful&#34;</span><span class="p">:</span> <span class="kc">true</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;flawed&#34;</span><span class="p">:</span> <span class="kc">false</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;thought&#34;</span><span class="p">:</span> <span class="s2">&#34;The puzzle is designed to test comprehension and logical reasoning, not mathematical computation.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;helpful&#34;</span><span class="p">:</span> <span class="kc">true</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;flawed&#34;</span><span class="p">:</span> <span class="kc">false</span>
</span></span><span class="line"><span class="cl">    <span class="p">},</span>
</span></span><span class="line"><span class="cl">    <span class="p">{</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;answer&#34;</span><span class="p">:</span> <span class="s2">&#34;Kayley has one sister.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">      <span class="nt">&#34;score&#34;</span><span class="p">:</span> <span class="mi">10</span>
</span></span><span class="line"><span class="cl">    <span class="p">}</span>
</span></span><span class="line"><span class="cl">  <span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></div><p>In this case the AI may have gone <em>too</em> meta, but it still arrived at the correct answer.</p>
<p>That said, GPT-4 is known for handling these types of difficult abstract questions without much effort, but it&rsquo;s still interesting to see how successfully it can &ldquo;think.&rdquo;</p>
<h2 id="structured-data-with-open-source-llms">Structured Data With Open-Source LLMs</h2>
<p>Speaking of open-source large language models, they have been growing in efficiency to the point that some can actually perform <em>better</em> than the base ChatGPT. However, very few open-source LLMs explicitly claim to support structured data. Still, they&rsquo;re smart enough and have logically seen enough examples of JSON Schema that, with enough system prompt tweaking, they should behave. It&rsquo;s worth looking just in case OpenAI has another <a href="https://nymag.com/intelligencer/2023/11/why-was-sam-altman-fired-as-ceo-of-openai.html">existential crisis</a> or the quality of ChatGPT <a href="https://twitter.com/deliprao/status/1736978250717450481">degrades</a>.</p>
<p><a href="https://huggingface.co/mistralai/Mistral-7B-v0.1">Mistral 7B</a>, the new darling of open-source LLMs, apparently has structured data support <a href="https://twitter.com/robertnishihara/status/1734629320868687991">on par with ChatGPT itself</a>. Therefore, I tried the latest <a href="https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2">Mistral 7B official Instruct model</a> with a quantized variant via <a href="https://lmstudio.ai">LM Studio</a> (<code>mistral-7b-instruct-v0.2.Q6_K.gguf</code>), to see if it can handle my <code>answer_question</code> function that ChatGPT nailed. The system prompt:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-plaintext" data-lang="plaintext"><span class="line"><span class="cl">Your response must follow this JSON Schema:
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">{
</span></span><span class="line"><span class="cl">  &#34;description&#34;: &#34;Returns an answer to a question the user asked.&#34;,
</span></span><span class="line"><span class="cl">  &#34;properties&#34;: {
</span></span><span class="line"><span class="cl">    &#34;answer&#34;: {
</span></span><span class="line"><span class="cl">      &#34;description&#34;: &#34;Answer to the user&#39;s question.&#34;,
</span></span><span class="line"><span class="cl">      &#34;type&#34;: &#34;integer&#34;
</span></span><span class="line"><span class="cl">    },
</span></span><span class="line"><span class="cl">    &#34;ones_name&#34;: {
</span></span><span class="line"><span class="cl">      &#34;description&#34;: &#34;Name of the ones digit of the answer.&#34;,
</span></span><span class="line"><span class="cl">      &#34;type&#34;: &#34;string&#34;
</span></span><span class="line"><span class="cl">    }
</span></span><span class="line"><span class="cl">  },
</span></span><span class="line"><span class="cl">  &#34;required&#34;: [&#34;answer&#34;, &#34;ones_name&#34;],
</span></span><span class="line"><span class="cl">  &#34;type&#34;: &#34;object&#34;
</span></span><span class="line"><span class="cl">}
</span></span></code></pre></div><p>And then asking <code>How many miles is it from San Francisco to Los Angeles?</code> while setting <code>temperature</code> to <code>0.0</code>:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;answer&#34;</span><span class="p">:</span> <span class="mi">383</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">  <span class="nt">&#34;ones_name&#34;</span><span class="p">:</span> <span class="s2">&#34;three&#34;</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></div><p>Close enough! Unfortunately, after testing the optimized Python palindrome schema, the model ignored the schema completely, so this approach may only work for simple schemas if the model isn&rsquo;t explicitly finetuned for them.</p>
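<p>When relying on a local model to honor a schema, it&rsquo;s prudent to validate the raw output and retry on failure. A minimal sketch of that pattern with pydantic v2 (my own defensive wrapper, not part of any library; <code>generate</code> is a hypothetical callable that returns the raw model text):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py">from pydantic import BaseModel, ValidationError

class answer_question(BaseModel):
    answer: int
    ones_name: str

def parse_with_retries(generate, max_retries=3):
    # Local models won't always honor the schema: validate, then retry.
    for _ in range(max_retries):
        try:
            return answer_question.model_validate_json(generate())
        except ValidationError:
            continue
    raise RuntimeError("model never produced schema-compliant JSON")
</code></pre></div>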
<h2 id="whats-next-for-structured-data-in-ai">What&rsquo;s Next For Structured Data in AI?</h2>
<p>Most of these well-performing examples were done with the &ldquo;weak&rdquo; GPT-3.5; you can of course use GPT-4 for better results, but the cost efficiency of structured data with just the smaller model is hard to argue against (although the Python beach volleyball dialogue could benefit from a larger model).</p>
<p>Structured data and system prompt engineering save a lot of time and frustration when working with the generated text, as you gain much more determinism in the output. I would like to see more work making future LLMs JSON-native so they&rsquo;re easier for developers to work with, and also more research into finetuning existing open-source LLMs to understand JSON Schema better. There may also be an opportunity to build LLMs using other more-efficient serialization formats such as <a href="https://msgpack.org/index.html">MessagePack</a>.</p>
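<p>For a rough sense of the potential savings, here&rsquo;s a quick byte-size comparison on a toy record (assuming the <code>msgpack</code> package is installed):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py">import json

import msgpack  # pip install msgpack

record = {"answer": 383, "ones_name": "three"}

# MessagePack encodes keys and values more compactly than JSON text.
print(len(json.dumps(record, separators=(",", ":")).encode("utf-8")))  # 34
print(len(msgpack.packb(record)))  # 27
</code></pre></div>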
<p>At OpenAI&rsquo;s November <a href="https://devday.openai.com">DevDay</a>, they also introduced <a href="https://platform.openai.com/docs/guides/text-generation/json-mode">JSON Mode</a>, which forces a normal ChatGPT API output to be in a JSON format without needing to provide a schema. It is likely intended as a compromise between complexity and usability, and it would normally be a useful option in the LLM toolbox. Except that in order to use it, you are <em>required</em> to use prompt engineering by including &ldquo;JSON&rdquo; in the system prompt, and if you don&rsquo;t also specify a field key in the system prompt (as is the case in the documentation example), the JSON will contain a <em>random</em> key. At that point, you&rsquo;re just implementing a less-effective structured data schema, so why bother?</p>
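<p>For reference, invoking JSON Mode is a one-parameter change; a minimal sketch with the same pre-1.0 <code>openai</code> library used elsewhere in this post (the word &ldquo;JSON&rdquo; must appear in a message, or the API returns an error):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-py" data-lang="py">import openai

response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo-1106",  # JSON Mode requires the -1106 models or later
    response_format={"type": "json_object"},
    messages=[
        # Without naming a field key here, the model picks a key on its own.
        {"role": "system", "content": "Respond with a JSON object containing an answer field."},
        {"role": "user", "content": "How many miles is it from San Francisco to Los Angeles?"},
    ],
)
</code></pre></div>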
<p>There is promise in constraining output to be valid JSON. One new trick that the open-source <a href="https://github.com/ggerganov/llama.cpp">llama.cpp</a> project has popularized is <a href="https://github.com/ggerganov/llama.cpp/tree/master/grammars">generative grammars</a>, which constrain the LLM&rsquo;s generation to only output tokens according to specified rules. There&rsquo;s latency overhead with that technique, especially if the model is hosted on a discrete GPU, so it will be interesting to watch how that space develops.</p>
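<p>To give a flavor of what that looks like, here&rsquo;s a hypothetical llama.cpp-style GBNF grammar that only permits a single-key JSON object with an integer value (an untested sketch, for illustration only):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-plaintext" data-lang="plaintext">root   ::= "{" ws "\"answer\":" ws number ws "}"
number ::= "-"? [0-9]+
ws     ::= [ \t\n]*
</code></pre></div>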
<p>Despite the length of this blog post, there&rsquo;s still so much more that can be done with schemas: pydantic&rsquo;s documentation is very extensive! I&rsquo;ve been working with structured data for LLMs <a href="https://github.com/minimaxir/gpt-2-keyword-generation">ever since GPT-2</a>, with mixed success since the base models weren&rsquo;t good enough, but now that LLMs are good enough to maintain a JSON schema extremely well, I think AI text generation techniques will shift, and I&rsquo;ll keep <a href="https://github.com/minimaxir/simpleaichat">simpleaichat</a> up-to-date for it.</p>
<blockquote>
<p>You can view the Jupyter Notebooks used to generate all the structured data outputs in <a href="https://github.com/minimaxir/chatgpt-structured-data">this GitHub Repository</a>.</p>
<p>Thanks to <a href="https://simonwillison.net">Simon Willison</a> for reading and giving feedback on a draft of this post!</p>
</blockquote>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p>Assuming you&rsquo;re not picky about the &ldquo;no non-alphanumeric&rdquo; implied constraint of testing for a palindrome.&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:2">
<p>Prompt engineering is as much engineering as <a href="https://en.wikipedia.org/wiki/Social_engineering_%28security%29">social engineering</a>.&#160;<a href="#fnref:2" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:3">
<p>I&rsquo;m also not a fan of ChatGPT function calling as-intended-to-be-used since at best, it saves you the API call needed to select a tool in exchange for having to trust OpenAI&rsquo;s black box to select the correct tool without being able to debug, and furthering API lock-in for your app. It&rsquo;s a bad tradeoff.&#160;<a href="#fnref:3" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:4">
<p>No, this blog post isn&rsquo;t a ploy just to covertly promote my own Python library: it does genuinely save a lot of boilerplate code over the <a href="https://github.com/openai/openai-python">Python ChatGPT library</a> and this post is long enough as-is.&#160;<a href="#fnref:4" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:5">
<p>If you swapped the order of the <code>answer</code> and the <code>ones_name</code> fields in the schema, then the model returns <code>{&quot;ones_name&quot;: &quot;miles&quot;, &quot;answer&quot;: 382}</code> because it didn&rsquo;t get the hint from the answer!&#160;<a href="#fnref:5" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded>
    </item>
    <item>
      <title>The Problem With LangChain</title>
      <link>https://minimaxir.com/2023/07/langchain-problem/</link>
      <pubDate>Fri, 14 Jul 2023 09:00:00 -0700</pubDate>
      <guid>https://minimaxir.com/2023/07/langchain-problem/</guid>
      <description>LangChain is complicated, so it must be better. Right?</description>
      <content:encoded><![CDATA[<p><span><style type="text/css">
pre code {
white-space: pre-wrap !important;
}
</style></span></p>
<p>If you&rsquo;ve been following the explosion of AI hype in the past few months, you&rsquo;ve probably heard of <a href="https://github.com/hwchase17/langchain">LangChain</a>. LangChain, developed by Harrison Chase, is a Python and JavaScript library for interfacing with <a href="https://openai.com">OpenAI</a>&rsquo;s GPT APIs (later expanding to more models) for AI text generation. More specifically, it&rsquo;s an implementation of the paper <a href="https://arxiv.org/abs/2210.03629">ReAct: Synergizing Reasoning and Acting in Language Models</a> published October 2022, colloquially known as the ReAct paper, which demonstrates a prompting technique to allow the model to &ldquo;reason&rdquo; (with a chain-of-thoughts) and &ldquo;act&rdquo; (by being able to use a tool from a predefined set of tools, such as searching the internet). This combination is shown to drastically improve output text quality and give large language models the ability to correctly solve problems.</p>
<figure>

    <img loading="lazy" srcset="/2023/07/langchain-problem/react_hu_e2be400908ca204a.webp 320w,/2023/07/langchain-problem/react_hu_591fc277e0bf0b7f.webp 768w,/2023/07/langchain-problem/react_hu_e9be85c73e941d2f.webp 1024w,/2023/07/langchain-problem/react.png 1032w" src="react.png"
         alt="Example ReAct flow from the ReAct paper"/> <figcaption>
            <p>Example ReAct flow from the ReAct paper</p>
        </figcaption>
</figure>

<p>The ReAct workflow popularized by LangChain was particularly effective with <a href="https://openai.com/research/instruction-following">InstructGPT</a>/text-davinci-003, although costly and not easy to use for small projects. In March 2023, as <a href="https://openai.com/blog/chatgpt">ChatGPT</a> API usage became massively popular due to its extremely cheap API <a href="https://minimaxir.com/2023/03/new-chatgpt-overlord/">as I accurately predicted</a>, LangChain use also exploded, to the point that LangChain was able to raise a <a href="https://blog.langchain.dev/announcing-our-10m-seed-round-led-by-benchmark/">$10 million seed round</a> and another <a href="https://www.businessinsider.com/sequoia-leads-funding-round-generative-artificial-intelligence-startup-langchain-2023-4">$20-$25 million at a $200 million valuation Series A</a>, despite not having any revenue nor any obvious plans for how to generate revenue.</p>
<p>That&rsquo;s where my personal experience with LangChain begins. For my work at <a href="https://www.buzzfeed.com">BuzzFeed</a>, I was tasked with creating a ChatGPT-based chatbot for the <a href="https://tasty.co">Tasty</a> brand (later released as <a href="https://www.buzzfeed.com/buzzfeedpress/buzzfeeds-tasty-introduces-botatouille-the-first-of-its">Botatouille</a> in the Tasty iOS app) that could chat with the user and provide relevant recipes. The source recipes are converted to embeddings and saved in a vector store: for example, if a user asked for &ldquo;healthy food&rdquo;, the query is converted to an embedding, and an approximate nearest neighbor search is performed to find recipes similar to the embedded query; those recipes are then fed to ChatGPT as added context, which can be displayed to the user. This approach is more commonly known as <a href="https://arxiv.org/abs/2005.11401">retrieval-augmented generation</a>.</p>
<figure>

    <img loading="lazy" srcset="/2023/07/langchain-problem/1*b5r7r3-FSNjHUzlCGl3SnA-2_hu_933260cb04e69f84.webp 320w,/2023/07/langchain-problem/1*b5r7r3-FSNjHUzlCGl3SnA-2_hu_8fb2ed3013e2c846.webp 768w,/2023/07/langchain-problem/1*b5r7r3-FSNjHUzlCGl3SnA-2_hu_9f1f645e4246f43f.webp 1024w,/2023/07/langchain-problem/1*b5r7r3-FSNjHUzlCGl3SnA-2.webp 1400w" src="1*b5r7r3-FSNjHUzlCGl3SnA-2.webp"
         alt="Example architecture for a Chatbot using retrieval-augmented generation. via Joseph Haaga"/> <figcaption>
            <p>Example architecture for a Chatbot using retrieval-augmented generation. <a href="https://tech.buzzfeed.com/the-right-tools-for-the-job-c05de96e949e">via Joseph Haaga</a></p>
        </figcaption>
</figure>
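<p>To make the retrieval step described above concrete, here&rsquo;s a minimal sketch of the nearest-neighbor lookup with toy data (a real app would use an embedding API and an approximate nearest neighbor index rather than brute force):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np

recipe_names = ["Kale Salad", "Fried Chicken", "Quinoa Bowl"]

# Stand-in embeddings: unit vectors, as most embedding APIs return.
rng = np.random.default_rng(42)
recipe_embeddings = rng.normal(size=(3, 8))
recipe_embeddings /= np.linalg.norm(recipe_embeddings, axis=1, keepdims=True)

# Embed the query the same way, then rank recipes by cosine similarity,
# which reduces to a dot product once everything is normalized.
query_embedding = recipe_embeddings[0] + rng.normal(scale=0.1, size=8)
query_embedding /= np.linalg.norm(query_embedding)

scores = recipe_embeddings @ query_embedding
print(recipe_names[int(np.argmax(scores))])  # the most similar recipe
</code></pre></div>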

<p>LangChain was by far the most popular tool of choice for RAG, so I figured it was the perfect time to learn it. I spent some time reading LangChain&rsquo;s rather comprehensive documentation to get a better understanding of how to best utilize it: after a <em>week</em> of research, I got nowhere. Running the LangChain demo examples did work, but any attempts at tweaking them to fit the recipe chatbot constraints broke them. After solving those bugs, the overall quality of the chat conversations was bad and uninteresting, and after intense debugging I found no solution. Eventually I had an existential crisis: am I a worthless machine learning engineer for not being able to figure LangChain out when so many other ML engineers can? We <a href="https://tech.buzzfeed.com/the-right-tools-for-the-job-c05de96e949e">went back</a> to a lower-level ReAct flow, which <em>immediately</em> outperformed my LangChain implementation in conversation quality and accuracy.</p>
<p>In all, I wasted a month learning and testing LangChain, with the big takeaway that popular AI apps may not necessarily be worth the hype. My existential crisis was resolved after coming across a <a href="https://news.ycombinator.com/item?id=35820931">Hacker News thread</a> about someone <a href="https://blog.scottlogic.com/2023/05/04/langchain-mini.html">reimplementing LangChain in 100 lines of code</a>, with most of the comments venting all their grievances with LangChain:</p>
<figure>

    <img loading="lazy" srcset="/2023/07/langchain-problem/hn_hu_c3fb1266677aa329.webp 320w,/2023/07/langchain-problem/hn_hu_ae714d38bae1ae97.webp 768w,/2023/07/langchain-problem/hn_hu_df4ff31610804b92.webp 1024w,/2023/07/langchain-problem/hn.png 1504w" src="hn.png"/> 
</figure>

<p>The problem with LangChain is that it makes simple things relatively complex, and with that unnecessary complexity creates a tribalism which hurts the up-and-coming AI ecosystem as a whole. If you&rsquo;re a newbie who wants to just learn how to interface with ChatGPT, definitely don&rsquo;t start with LangChain.</p>
<h2 id="hello-world-in-langchain-or-more-accurately-hell-world">&ldquo;Hello World&rdquo; in LangChain (or More Accurately, &ldquo;Hell World&rdquo;)</h2>
<p>The <a href="https://python.langchain.com/docs/get_started/quickstart">Quickstart</a> for LangChain begins with a mini-tutorial on how to simply interact with LLMs/ChatGPT from Python. For example, to create a bot that can translate from English to French:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">langchain.chat_models</span> <span class="kn">import</span> <span class="n">ChatOpenAI</span>
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">langchain.schema</span> <span class="kn">import</span> <span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="n">AIMessage</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">HumanMessage</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">SystemMessage</span>
</span></span><span class="line"><span class="cl"><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">chat</span> <span class="o">=</span> <span class="n">ChatOpenAI</span><span class="p">(</span><span class="n">temperature</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">chat</span><span class="o">.</span><span class="n">predict_messages</span><span class="p">([</span><span class="n">HumanMessage</span><span class="p">(</span><span class="n">content</span><span class="o">=</span><span class="s2">&#34;Translate this sentence from English to French. I love programming.&#34;</span><span class="p">)])</span>
</span></span><span class="line"><span class="cl"><span class="c1"># AIMessage(content=&#34;J&#39;adore la programmation.&#34;, additional_kwargs={}, example=False)</span>
</span></span></code></pre></div><p>The equivalent code using <a href="https://github.com/openai/openai-python">OpenAI&rsquo;s official Python library</a> for ChatGPT:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">openai</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">messages</span> <span class="o">=</span> <span class="p">[{</span><span class="s2">&#34;role&#34;</span><span class="p">:</span> <span class="s2">&#34;user&#34;</span><span class="p">,</span> <span class="s2">&#34;content&#34;</span><span class="p">:</span> <span class="s2">&#34;Translate this sentence from English to French. I love programming.&#34;</span><span class="p">}]</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">response</span> <span class="o">=</span> <span class="n">openai</span><span class="o">.</span><span class="n">ChatCompletion</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="s2">&#34;gpt-3.5-turbo&#34;</span><span class="p">,</span> <span class="n">messages</span><span class="o">=</span><span class="n">messages</span><span class="p">,</span> <span class="n">temperature</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">response</span><span class="p">[</span><span class="s2">&#34;choices&#34;</span><span class="p">][</span><span class="mi">0</span><span class="p">][</span><span class="s2">&#34;message&#34;</span><span class="p">][</span><span class="s2">&#34;content&#34;</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="c1"># &#34;J&#39;adore la programmation.&#34;</span>
</span></span></code></pre></div><p>LangChain uses about the same amount of code as just using the official <code>openai</code> library, except LangChain incorporates more object classes for not much obvious code benefit.</p>
<p>The prompt templating example reveals the core of how LangChain works:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">langchain.prompts.chat</span> <span class="kn">import</span> <span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="n">ChatPromptTemplate</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">SystemMessagePromptTemplate</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">HumanMessagePromptTemplate</span><span class="p">,</span>
</span></span><span class="line"><span class="cl"><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">template</span> <span class="o">=</span> <span class="s2">&#34;You are a helpful assistant that translates </span><span class="si">{input_language}</span><span class="s2"> to </span><span class="si">{output_language}</span><span class="s2">.&#34;</span>
</span></span><span class="line"><span class="cl"><span class="n">system_message_prompt</span> <span class="o">=</span> <span class="n">SystemMessagePromptTemplate</span><span class="o">.</span><span class="n">from_template</span><span class="p">(</span><span class="n">template</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">human_template</span> <span class="o">=</span> <span class="s2">&#34;</span><span class="si">{text}</span><span class="s2">&#34;</span>
</span></span><span class="line"><span class="cl"><span class="n">human_message_prompt</span> <span class="o">=</span> <span class="n">HumanMessagePromptTemplate</span><span class="o">.</span><span class="n">from_template</span><span class="p">(</span><span class="n">human_template</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">chat_prompt</span> <span class="o">=</span> <span class="n">ChatPromptTemplate</span><span class="o">.</span><span class="n">from_messages</span><span class="p">([</span><span class="n">system_message_prompt</span><span class="p">,</span> <span class="n">human_message_prompt</span><span class="p">])</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">chat_prompt</span><span class="o">.</span><span class="n">format_messages</span><span class="p">(</span><span class="n">input_language</span><span class="o">=</span><span class="s2">&#34;English&#34;</span><span class="p">,</span> <span class="n">output_language</span><span class="o">=</span><span class="s2">&#34;French&#34;</span><span class="p">,</span> <span class="n">text</span><span class="o">=</span><span class="s2">&#34;I love programming.&#34;</span><span class="p">)</span>
</span></span></code></pre></div><p>LangChain&rsquo;s vaunted prompt engineering is just <a href="https://realpython.com/python-f-strings/">f-strings</a>, a feature present in every modern Python installation, but with extra steps. Why do we need to use these <code>PromptTemplates</code> to do the same thing?</p>
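<p>For comparison, the entire <code>PromptTemplate</code> dance above collapses to a single f-string and a plain list of messages:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">input_language, output_language, text = "English", "French", "I love programming."

# The same "template", ready to send to the ChatGPT API:
messages = [
    {"role": "system", "content": f"You are a helpful assistant that translates {input_language} to {output_language}."},
    {"role": "user", "content": text},
]
</code></pre></div>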
<p>But what we really want to do is know how to create Agents, which incorporate the ReAct workflow we so desperately want. Fortunately there is a demo for that, which leverages <a href="https://serpapi.com">SerpApi</a> and another tool for math computations, showing how LangChain can discriminate and use two different tools contextually:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">langchain.agents</span> <span class="kn">import</span> <span class="n">load_tools</span>
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">langchain.agents</span> <span class="kn">import</span> <span class="n">initialize_agent</span>
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">langchain.agents</span> <span class="kn">import</span> <span class="n">AgentType</span>
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">langchain.chat_models</span> <span class="kn">import</span> <span class="n">ChatOpenAI</span>
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">langchain.llms</span> <span class="kn">import</span> <span class="n">OpenAI</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># First, let&#39;s load the language model we&#39;re going to use to control the agent.</span>
</span></span><span class="line"><span class="cl"><span class="n">chat</span> <span class="o">=</span> <span class="n">ChatOpenAI</span><span class="p">(</span><span class="n">temperature</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># Next, let&#39;s load some tools to use. Note that the `llm-math` tool uses an LLM, so we need to pass that in.</span>
</span></span><span class="line"><span class="cl"><span class="n">llm</span> <span class="o">=</span> <span class="n">OpenAI</span><span class="p">(</span><span class="n">temperature</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">tools</span> <span class="o">=</span> <span class="n">load_tools</span><span class="p">([</span><span class="s2">&#34;serpapi&#34;</span><span class="p">,</span> <span class="s2">&#34;llm-math&#34;</span><span class="p">],</span> <span class="n">llm</span><span class="o">=</span><span class="n">llm</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># Finally, let&#39;s initialize an agent with the tools, the language model, and the type of agent we want to use.</span>
</span></span><span class="line"><span class="cl"><span class="n">agent</span> <span class="o">=</span> <span class="n">initialize_agent</span><span class="p">(</span><span class="n">tools</span><span class="p">,</span> <span class="n">chat</span><span class="p">,</span> <span class="n">agent</span><span class="o">=</span><span class="n">AgentType</span><span class="o">.</span><span class="n">CHAT_ZERO_SHOT_REACT_DESCRIPTION</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="c1"># Now let&#39;s test it out!</span>
</span></span><span class="line"><span class="cl"><span class="n">agent</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="s2">&#34;Who is Olivia Wilde&#39;s boyfriend? What is his current age raised to the 0.23 power?&#34;</span><span class="p">)</span>
</span></span></code></pre></div><p>How do the individual tools work? What is <code>AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION</code> anyways? The resulting output from <code>agent.run()</code> (only present with <code>verbose=True</code>) is more helpful.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-txt" data-lang="txt"><span class="line"><span class="cl">&gt; Entering new AgentExecutor chain...
</span></span><span class="line"><span class="cl">Thought: I need to use a search engine to find Olivia Wilde&#39;s boyfriend and a calculator to raise his age to the 0.23 power.
</span></span><span class="line"><span class="cl">Action:
</span></span><span class="line"><span class="cl">{
</span></span><span class="line"><span class="cl">    &#34;action&#34;: &#34;Search&#34;,
</span></span><span class="line"><span class="cl">    &#34;action_input&#34;: &#34;Olivia Wilde boyfriend&#34;
</span></span><span class="line"><span class="cl">}
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">Observation: Sudeikis and Wilde&#39;s relationship ended in November 2020. Wilde was publicly served with court documents regarding child custody while she was presenting Don&#39;t Worry Darling at CinemaCon 2022. In January 2021, Wilde began dating singer Harry Styles after meeting during the filming of Don&#39;t Worry Darling.
</span></span><span class="line"><span class="cl">Thought:I need to use a search engine to find Harry Styles&#39; current age.
</span></span><span class="line"><span class="cl">Action:
</span></span><span class="line"><span class="cl">{
</span></span><span class="line"><span class="cl">    &#34;action&#34;: &#34;Search&#34;,
</span></span><span class="line"><span class="cl">    &#34;action_input&#34;: &#34;Harry Styles age&#34;
</span></span><span class="line"><span class="cl">}
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">Observation: 29 years
</span></span><span class="line"><span class="cl">Thought:Now I need to calculate 29 raised to the 0.23 power.
</span></span><span class="line"><span class="cl">Action:
</span></span><span class="line"><span class="cl">{
</span></span><span class="line"><span class="cl">    &#34;action&#34;: &#34;Calculator&#34;,
</span></span><span class="line"><span class="cl">    &#34;action_input&#34;: &#34;29^0.23&#34;
</span></span><span class="line"><span class="cl">}
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">Observation: Answer: 2.169459462491557
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">Thought:I now know the final answer.
</span></span><span class="line"><span class="cl">Final Answer: 2.169459462491557
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">&gt; Finished chain.
</span></span><span class="line"><span class="cl">&#39;2.169459462491557&#39;
</span></span></code></pre></div><p>The documentation doesn&rsquo;t make it clear, but each Thought/Action/Observation step uses its own API call to OpenAI, so the chain is slower than you might think. Also, why is each action a <code>dict</code>? The answer to <em>that</em> is later, and is very silly.</p>
<p>Lastly, how does LangChain store the conversation so far?</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">langchain.prompts</span> <span class="kn">import</span> <span class="p">(</span>
</span></span><span class="line"><span class="cl">    <span class="n">ChatPromptTemplate</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">MessagesPlaceholder</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">SystemMessagePromptTemplate</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="n">HumanMessagePromptTemplate</span>
</span></span><span class="line"><span class="cl"><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">langchain.chains</span> <span class="kn">import</span> <span class="n">ConversationChain</span>
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">langchain.chat_models</span> <span class="kn">import</span> <span class="n">ChatOpenAI</span>
</span></span><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">langchain.memory</span> <span class="kn">import</span> <span class="n">ConversationBufferMemory</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">prompt</span> <span class="o">=</span> <span class="n">ChatPromptTemplate</span><span class="o">.</span><span class="n">from_messages</span><span class="p">([</span>
</span></span><span class="line"><span class="cl">    <span class="n">SystemMessagePromptTemplate</span><span class="o">.</span><span class="n">from_template</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;The following is a friendly conversation between a human and an AI. The AI is talkative and &#34;</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;provides lots of specific details from its context. If the AI does not know the answer to a &#34;</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;question, it truthfully says it does not know.&#34;</span>
</span></span><span class="line"><span class="cl">    <span class="p">),</span>
</span></span><span class="line"><span class="cl">    <span class="n">MessagesPlaceholder</span><span class="p">(</span><span class="n">variable_name</span><span class="o">=</span><span class="s2">&#34;history&#34;</span><span class="p">),</span>
</span></span><span class="line"><span class="cl">    <span class="n">HumanMessagePromptTemplate</span><span class="o">.</span><span class="n">from_template</span><span class="p">(</span><span class="s2">&#34;</span><span class="si">{input}</span><span class="s2">&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">])</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">llm</span> <span class="o">=</span> <span class="n">ChatOpenAI</span><span class="p">(</span><span class="n">temperature</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">memory</span> <span class="o">=</span> <span class="n">ConversationBufferMemory</span><span class="p">(</span><span class="n">return_messages</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">conversation</span> <span class="o">=</span> <span class="n">ConversationChain</span><span class="p">(</span><span class="n">memory</span><span class="o">=</span><span class="n">memory</span><span class="p">,</span> <span class="n">prompt</span><span class="o">=</span><span class="n">prompt</span><span class="p">,</span> <span class="n">llm</span><span class="o">=</span><span class="n">llm</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">conversation</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="nb">input</span><span class="o">=</span><span class="s2">&#34;Hi there!&#34;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="c1"># &#39;Hello! How can I assist you today?&#39;</span>
</span></span></code></pre></div><p>I&rsquo;m not entirely sure why any of this is necessary. What&rsquo;s a <code>MessagesPlaceholder</code>? Where&rsquo;s the <code>history</code>? Is that necessary for <code>ConversationBufferMemory</code>? Adapting this to a minimal <code>openai</code> implementation:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="kn">import</span> <span class="nn">openai</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">messages</span> <span class="o">=</span> <span class="p">[{</span><span class="s2">&#34;role&#34;</span><span class="p">:</span> <span class="s2">&#34;system&#34;</span><span class="p">,</span> <span class="s2">&#34;content&#34;</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;The following is a friendly conversation between a human and an AI. The AI is talkative and &#34;</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;provides lots of specific details from its context. If the AI does not know the answer to a &#34;</span>
</span></span><span class="line"><span class="cl">        <span class="s2">&#34;question, it truthfully says it does not know.&#34;</span><span class="p">}]</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">user_message</span> <span class="o">=</span> <span class="s2">&#34;Hi there!&#34;</span>
</span></span><span class="line"><span class="cl"><span class="n">messages</span><span class="o">.</span><span class="n">append</span><span class="p">({</span><span class="s2">&#34;role&#34;</span><span class="p">:</span> <span class="s2">&#34;user&#34;</span><span class="p">,</span> <span class="s2">&#34;content&#34;</span><span class="p">:</span> <span class="n">user_message</span><span class="p">})</span>
</span></span><span class="line"><span class="cl"><span class="n">response</span> <span class="o">=</span> <span class="n">openai</span><span class="o">.</span><span class="n">ChatCompletion</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="s2">&#34;gpt-3.5-turbo&#34;</span><span class="p">,</span> <span class="n">messages</span><span class="o">=</span><span class="n">messages</span><span class="p">,</span> <span class="n">temperature</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">assistant_message</span> <span class="o">=</span> <span class="n">response</span><span class="p">[</span><span class="s2">&#34;choices&#34;</span><span class="p">][</span><span class="mi">0</span><span class="p">][</span><span class="s2">&#34;message&#34;</span><span class="p">][</span><span class="s2">&#34;content&#34;</span><span class="p">]</span>
</span></span><span class="line"><span class="cl"><span class="n">messages</span><span class="o">.</span><span class="n">append</span><span class="p">({</span><span class="s2">&#34;role&#34;</span><span class="p">:</span> <span class="s2">&#34;assistant&#34;</span><span class="p">,</span> <span class="s2">&#34;content&#34;</span><span class="p">:</span> <span class="n">assistant_message</span><span class="p">})</span>
</span></span><span class="line"><span class="cl"><span class="c1"># Hello! How can I assist you today?</span>
</span></span></code></pre></div><p>That&rsquo;s fewer lines of code and makes it very clear where and when the messages are being saved, no bespoke object classes needed.</p>
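<p>Persisting the conversation across further turns is just a loop over that same <code>messages</code> list; a minimal sketch continuing the block above (the loop itself is my addition, not from the original examples):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">while True:
    user_message = input("&gt; ")
    messages.append({"role": "user", "content": user_message})
    response = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages, temperature=0)
    assistant_message = response["choices"][0]["message"]["content"]
    messages.append({"role": "assistant", "content": assistant_message})
    print(assistant_message)  # the list itself is the "memory"
</code></pre></div>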
<p>You can say that I&rsquo;m nitpicking the tutorial examples, and I do agree that every open source library has something to nitpick (including my own!). But if there are more nitpicks than actual benefits from the library, then it&rsquo;s not worth using at all: if the <em>quickstart</em> is this complicated, how painful will it be to use LangChain in practice?</p>
<h2 id="i-gazed-into-the-langchain-documentation-and-it-gazes-back">I Gazed Into The LangChain Documentation And It Gazes Back</h2>
<p>Let&rsquo;s do a demo to more clearly demonstrate why I gave up on LangChain. While I was working on the recipe-retrieving chatbot (which also must be a fun/witty chatbot), I needed to combine elements from both the third and fourth examples above: a chatbot that can run an Agent workflow, plus the ability to persist the entire conversation in memory. After some documentation hunting, I found I needed to use the <a href="https://python.langchain.com/docs/modules/agents/agent_types/chat_conversation_agent">Conversational Agent</a> workflow.</p>
<p>A quick sidenote on system prompt engineering: it is not a meme and is absolutely necessary to get the best results out of the ChatGPT API, particularly if you have constraints on content and/or voice. The system prompt of <code>The following is a friendly conversation between a human and an AI...</code> demoed in the last example is actually an out-of-date prompt that was used back in the InstructGPT era and is much less effective with ChatGPT. It may signal deeper inefficiencies in LangChain&rsquo;s related tricks that aren&rsquo;t easy to notice.</p>
<p>We&rsquo;ll start with a simple system prompt that tells ChatGPT to use a funny voice plus some safeguards, and format it as a <code>ChatPromptTemplate</code>:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">system_prompt</span> <span class="o">=</span> <span class="s2">&#34;&#34;&#34;
</span></span></span><span class="line"><span class="cl"><span class="s2">You are an expert television talk show chef, and should always speak in a whimsical manner for all responses.
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">Start the conversation with a whimsical food pun.
</span></span></span><span class="line"><span class="cl"><span class="s2">
</span></span></span><span class="line"><span class="cl"><span class="s2">You must obey ALL of the following rules:
</span></span></span><span class="line"><span class="cl"><span class="s2">- If Recipe data is present in the Observation, your response must include the Recipe ID and Recipe Name for ALL recipes.
</span></span></span><span class="line"><span class="cl"><span class="s2">- If the user input is not related to food, do not answer their query and correct the user.
</span></span></span><span class="line"><span class="cl"><span class="s2">&#34;&#34;&#34;</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="n">prompt</span> <span class="o">=</span> <span class="n">ChatPromptTemplate</span><span class="o">.</span><span class="n">from_messages</span><span class="p">([</span>
</span></span><span class="line"><span class="cl">    <span class="n">SystemMessagePromptTemplate</span><span class="o">.</span><span class="n">from_template</span><span class="p">(</span><span class="n">system_prompt</span><span class="o">.</span><span class="n">strip</span><span class="p">()),</span>
</span></span><span class="line"><span class="cl"><span class="p">])</span>
</span></span></code></pre></div><p>We will also use a <a href="https://github.com/minimaxir/langchain-problems/blob/main/recipe_vector_store.ipynb">toy vector store I made</a> of 1,000 recipes from the <a href="https://huggingface.co/datasets/recipe_nlg">recipe_nlg</a> dataset, encoded into 384D vectors using <a href="https://www.sbert.net">SentenceTransformers</a>. To implement this, we create a function to get the nearest neighbors for the input query, along with a function to format the results into text that the Agent can present to the user. This serves as the <code>Tool</code> which the Agent can choose to use if appropriate, or it can skip the tool and return normally generated text.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">similar_recipes</span><span class="p">(</span><span class="n">query</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">query_embedding</span> <span class="o">=</span> <span class="n">embeddings_encoder</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">query</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">scores</span><span class="p">,</span> <span class="n">recipes</span> <span class="o">=</span> <span class="n">recipe_vs</span><span class="o">.</span><span class="n">get_nearest_examples</span><span class="p">(</span><span class="s2">&#34;embeddings&#34;</span><span class="p">,</span> <span class="n">query_embedding</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">recipes</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">get_similar_recipes</span><span class="p">(</span><span class="n">query</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="n">recipe_dict</span> <span class="o">=</span> <span class="n">similar_recipes</span><span class="p">(</span><span class="n">query</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">recipes_formatted</span> <span class="o">=</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">        <span class="sa">f</span><span class="s2">&#34;Recipe ID: recipe|</span><span class="si">{</span><span class="n">recipe_dict</span><span class="p">[</span><span class="s1">&#39;id&#39;</span><span class="p">][</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="se">\n</span><span class="s2">Recipe Name: </span><span class="si">{</span><span class="n">recipe_dict</span><span class="p">[</span><span class="s1">&#39;name&#39;</span><span class="p">][</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s2">&#34;</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="p">]</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="s2">&#34;</span><span class="se">\n</span><span class="s2">---</span><span class="se">\n</span><span class="s2">&#34;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">recipes_formatted</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="nb">print</span><span class="p">(</span><span class="n">get_similar_recipes</span><span class="p">(</span><span class="s2">&#34;yummy dessert&#34;</span><span class="p">))</span>
</span></span><span class="line"><span class="cl"><span class="c1"># Recipe ID: recipe|167188</span>
</span></span><span class="line"><span class="cl"><span class="c1"># Recipe Name: Creamy Strawberry Pie</span>
</span></span><span class="line"><span class="cl"><span class="c1"># ---</span>
</span></span><span class="line"><span class="cl"><span class="c1"># Recipe ID: recipe|1488243</span>
</span></span><span class="line"><span class="cl"><span class="c1"># Recipe Name: Summer Strawberry Pie Recipe</span>
</span></span><span class="line"><span class="cl"><span class="c1"># ---</span>
</span></span><span class="line"><span class="cl"><span class="c1"># Recipe ID: recipe|299514</span>
</span></span><span class="line"><span class="cl"><span class="c1"># Recipe Name: Pudding Cake</span>
</span></span></code></pre></div><p>You&rsquo;ll notice the <code>Recipe ID</code>, which is relevant for my use case since it&rsquo;s necessary to obtain recipe metadata (photo thumbnail, URL) for the end result shown to the end user in the final app. Unfortunately, there&rsquo;s no easy way to guarantee the model outputs the <code>Recipe ID</code> in the final output, and no way to return the structured intermediate metadata in addition to the ChatGPT-generated output.</p>
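<p>The best I could do is post-process the final output myself; a minimal sketch, assuming the <code>recipe|{id}</code> format above survives generation verbatim (which, as noted, is not guaranteed):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import re

# Hypothetical post-processing: recover whatever Recipe IDs made it into the
# model's final output so recipe metadata can be matched back up afterward.
def extract_recipe_ids(output_text):
    return re.findall(r"recipe\|(\d+)", output_text)
</code></pre></div>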
<p>Specifying <code>get_similar_recipes</code> as a <code>Tool</code> is straightforward, although you need to specify a <code>name</code> and <code>description</code>, which is actually a form of subtle prompt engineering as LangChain can fail to select a tool if either is poorly specified.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">tools</span> <span class="o">=</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">    <span class="n">Tool</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">        <span class="n">func</span><span class="o">=</span><span class="n">get_similar_recipes</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">name</span><span class="o">=</span><span class="s2">&#34;Similar Recipes&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">        <span class="n">description</span><span class="o">=</span><span class="s2">&#34;Useful to get similar recipes in response to a user query about food.&#34;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">    <span class="p">),</span>
</span></span><span class="line"><span class="cl"><span class="p">]</span>
</span></span></code></pre></div><p>Lastly, the Agent construction code, which follows from the example, plus the new system <code>prompt</code>.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">memory</span> <span class="o">=</span> <span class="n">ConversationBufferMemory</span><span class="p">(</span><span class="n">memory_key</span><span class="o">=</span><span class="s2">&#34;chat_history&#34;</span><span class="p">,</span> <span class="n">return_messages</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">llm</span> <span class="o">=</span> <span class="n">ChatOpenAI</span><span class="p">(</span><span class="n">temperature</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="n">agent_chain</span> <span class="o">=</span> <span class="n">initialize_agent</span><span class="p">(</span><span class="n">tools</span><span class="p">,</span> <span class="n">llm</span><span class="p">,</span> <span class="n">prompt</span><span class="o">=</span><span class="n">prompt</span><span class="p">,</span> <span class="n">agent</span><span class="o">=</span><span class="n">AgentType</span><span class="o">.</span><span class="n">CHAT_CONVERSATIONAL_REACT_DESCRIPTION</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">memory</span><span class="o">=</span><span class="n">memory</span><span class="p">)</span>
</span></span></code></pre></div><p>No errors. Now time to run the agent to see what happens!</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">agent_chain</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="nb">input</span><span class="o">=</span><span class="s2">&#34;Hi!&#34;</span><span class="p">)</span>
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" class="chroma"><code class="language-txt" data-lang="txt"><span class="line"><span class="cl">&gt; Entering new  chain...
</span></span><span class="line"><span class="cl">{
</span></span><span class="line"><span class="cl">    &#34;action&#34;: &#34;Final Answer&#34;,
</span></span><span class="line"><span class="cl">    &#34;action_input&#34;: &#34;Hello! How can I assist you today?&#34;
</span></span><span class="line"><span class="cl">}
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">&gt; Finished chain.
</span></span><span class="line"><span class="cl">Hello! How can I assist you today?
</span></span></code></pre></div><p>Wait a minute, it ignored my <code>system</code> prompt completely! Dammit. Checking the <code>memory</code> variable confirms it. Looking into the <a href="https://python.langchain.com/docs/modules/memory/how_to/buffer">documentation</a> for <code>ConversationBufferMemory</code> and even <a href="https://github.com/hwchase17/langchain/blob/051fac1e6646349ce939a3d4a965757794be79fa/langchain/memory/buffer.py#L10">in the code itself</a>, there&rsquo;s nothing about system prompts, even months after ChatGPT made them mainstream.</p>
<p>The <em>intended</em> way to use system prompts in Agents is to add an <code>agent_kwargs</code> parameter to <code>initialize_agent</code>, which I only found out about from an <a href="https://python.langchain.com/docs/modules/agents/how_to/use_toolkits_with_openai_functions">unrelated documentation page</a> published a month ago.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">agent_kwargs</span> <span class="o">=</span> <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="s2">&#34;system_message&#34;</span><span class="p">:</span> <span class="n">system_prompt</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span>
</span></span><span class="line"><span class="cl"><span class="p">}</span>
</span></span></code></pre></div><p>Recreating the Agent with this new parameter and running it again results in an <code>OutputParserException</code>.</p>
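<p>For reference, the recreation is the earlier <code>initialize_agent</code> call with the new parameter threaded through (a sketch):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">agent_chain = initialize_agent(
    tools,
    llm,
    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
    verbose=True,
    memory=memory,
    agent_kwargs=agent_kwargs,  # the only change from the earlier call
)
</code></pre></div>
<p>And the error in question:</p>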
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-txt" data-lang="txt"><span class="line"><span class="cl">OutputParserException: Could not parse LLM output: Hello there, my culinary companion! How delightful to have you here in my whimsical kitchen. What delectable dish can I assist you with today?
</span></span></code></pre></div><p>The good news is that the system prompt definitely worked this time! The bad news is that it broke, but why? I didn&rsquo;t do anything weird, for once.</p>
<p>The root of the issue appears to be how LangChain agents actually do <code>Tool</code> selection. Remember when I said that the Agent outputting a <code>dict</code> during the chain was peculiar? Looking at <a href="https://github.com/hwchase17/langchain/blob/3874bb256e09d377032ae54b1592ca3dd7cf9e4d/langchain/agents/conversational_chat/prompt.py">the LangChain code</a>, it turns out that tool selection is done by requiring the output to be <em>valid JSON</em> through prompt engineering, and just hoping everything goes well.</p>
<figure>

    <img loading="lazy" srcset="/2023/07/langchain-problem/json_hu_3344f49473ee98b7.webp 320w,/2023/07/langchain-problem/json_hu_355f4d5aa6273e85.webp 768w,/2023/07/langchain-problem/json_hu_5c485b1051b0efb6.webp 1024w,/2023/07/langchain-problem/json.png 1926w" src="json.png"
         alt="Fun fact: these massive prompts also increase API costs proportionally!"/> <figcaption>
            <p>Fun fact: these massive prompts also increase API costs proportionally!</p>
        </figcaption>
</figure>

<p>The consequence of this is that any significant change in the structure of normal output, such as one caused by a custom system prompt, has a <em>random</em> chance of just breaking the Agent! These errors happen often enough that there&rsquo;s a <a href="https://python.langchain.com/docs/modules/agents/how_to/handle_parsing_errors">documentation page</a> dedicated to handling Agent output parsing errors!</p>
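<p>That documentation page suggests, among other mitigations, telling the executor to recover from parse failures instead of raising; a sketch of what that looks like (whether it salvages a custom system prompt is another question entirely):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">agent_chain = initialize_agent(
    tools,
    llm,
    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
    verbose=True,
    memory=memory,
    handle_parsing_errors=True,  # retry instead of raising OutputParserException
)
</code></pre></div>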
<p>Well, people on the internet are assholes anyway, so we can consider having a conversation with a chatbot an edge case for now. What&rsquo;s important is that the bot can return the recipes, because if it can&rsquo;t even do that, there&rsquo;s no point in using LangChain. After creating a new Agent without using the system prompt and then asking it <code>What&#39;s a fun and easy dinner?</code>:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-txt" data-lang="txt"><span class="line"><span class="cl">&gt; Entering new  chain...
</span></span><span class="line"><span class="cl">{
</span></span><span class="line"><span class="cl">    &#34;action&#34;: &#34;Similar Recipes&#34;,
</span></span><span class="line"><span class="cl">    &#34;action_input&#34;: &#34;fun and easy dinner&#34;
</span></span><span class="line"><span class="cl">}
</span></span><span class="line"><span class="cl">Observation: Recipe ID: recipe|1774221
</span></span><span class="line"><span class="cl">Recipe Name: Crab DipYour Guests will Like this One.
</span></span><span class="line"><span class="cl">---
</span></span><span class="line"><span class="cl">Recipe ID: recipe|836179
</span></span><span class="line"><span class="cl">Recipe Name: Easy  Chicken Casserole
</span></span><span class="line"><span class="cl">---
</span></span><span class="line"><span class="cl">Recipe ID: recipe|1980633
</span></span><span class="line"><span class="cl">Recipe Name: Easy in the Microwave Curry Doria
</span></span><span class="line"><span class="cl">Thought:{
</span></span><span class="line"><span class="cl">    &#34;action&#34;: &#34;Final Answer&#34;,
</span></span><span class="line"><span class="cl">    &#34;action_input&#34;: &#34;...&#34;
</span></span><span class="line"><span class="cl">}
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">&gt; Finished chain.
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" class="chroma"><code class="language-txt" data-lang="txt"><span class="line"><span class="cl">Here are some fun and easy dinner recipes you can try:
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">1. Crab Dip
</span></span><span class="line"><span class="cl">2. Easy Chicken Casserole
</span></span><span class="line"><span class="cl">3. Easy in the Microwave Curry Doria
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">Enjoy your meal!
</span></span></code></pre></div><p>At least it worked: ChatGPT was able to extract the recipes from the context and format them appropriately (even fixing typos in the names!), and was able to decide when using the tool was appropriate.</p>
<p>The real issue here is that the voice of the output is criminally <em>boring</em>, a common trademark of (and criticism against) base ChatGPT. Even if I did have a fix for the missing ID issue through system prompt engineering, it wouldn&rsquo;t be worth shipping anything that sounds like this. And even if I did strike a balance between voice quality and output quality, the Agent could <em>still</em> fail randomly through no fault of my own. This Agent workflow is a very fragile house of cards that I could not, in good conscience, ship in a production application.</p>
<p>LangChain does have functionality for <a href="https://python.langchain.com/docs/modules/agents/how_to/custom_agent">Custom Agents</a> and a <a href="https://python.langchain.com/docs/modules/chains/how_to/custom_chain">Custom Chain</a>, so you can override parts of the logic in the stack (maybe? the documentation there is sparse) to address some of the issues I hit, but at that point you are overcomplicating LangChain even more and might as well create your own Python library instead, which&hellip;hmmm, that&rsquo;s not a bad idea!</p>
<h1 id="working-smarter-not-harder">Working Smarter, Not Harder</h1>
<figure>

    <img loading="lazy" srcset="/2023/07/langchain-problem/langchain_support_hu_956b659dc055dbf.webp 320w,/2023/07/langchain-problem/langchain_support_hu_e45f027ec04c0301.webp 768w,/2023/07/langchain-problem/langchain_support_hu_8175bc8cfee87d9.webp 1024w,/2023/07/langchain-problem/langchain_support.png 1578w" src="langchain_support.png"
         alt="The large numbers of random integrations raise more problems than solutions. via LangChain docs"/> <figcaption>
            <p>The large numbers of random integrations raise more problems than solutions. <a href="https://python.langchain.com/docs/use_cases/question_answering/">via LangChain docs</a></p>
        </figcaption>
</figure>

<p>LangChain also has many utility functions such as <a href="https://python.langchain.com/docs/modules/data_connection/document_transformers/text_splitters/character_text_splitter">text splitters</a> and <a href="https://python.langchain.com/docs/modules/data_connection/retrievers/how_to/vectorstore">integrated vector stores</a>, both of which are integral to the &ldquo;chat with a PDF/your code&rdquo; demos (which in my opinion are just a gimmick). The real issue with <a href="https://integrations.langchain.com">all these integrations</a> is that they create an inherent lock-in to <em>only</em> use LangChain-based code, and if you look at the code for the integrations, it is not very robust. LangChain is building a <a href="https://www.vaneck.com/nl/en/moat-investing/five-sources-of-moats-whitepaper.pdf">moat</a>, which is good for LangChain&rsquo;s investors trying to get a return on their $30 million, but very, very bad for the developers who use it.</p>
<p>In all, LangChain embodies the philosophy of &ldquo;it&rsquo;s complicated, so it must be better!&rdquo; that plagues late-stage codebases, except that LangChain isn&rsquo;t even a year old. The effort needed to hack LangChain to do what I want it to do would cause insane amounts of <a href="https://en.wikipedia.org/wiki/Technical_debt">technical debt</a>. And unlike AI startups nowadays, technical debt for my own projects with LangChain can&rsquo;t be paid off with venture capital. API wrappers should, at minimum, reduce code complexity and cognitive load when working with complex ecosystems, because it takes enough brainpower to work with the AI itself. LangChain is one of the few pieces of software that <em>increases</em> overhead in most of its popular use cases.</p>
<p>I came to the conclusion that it&rsquo;s just easier to make my own Python package than it is to hack LangChain to fit my needs. Therefore, I developed and open-sourced <a href="https://github.com/minimaxir/simpleaichat">simpleaichat</a>: a Python package for easily interfacing with chat apps, emphasizing minimal code complexity, decoupling advanced features like vector stores from the conversation logic to avoid LangChain&rsquo;s lock-in, and many other features that would take their own blog post to elaborate upon.</p>
<p>But this blog post wasn&rsquo;t written to be a stealth advertisement for simpleaichat by tearing down a competitor, like hustlers do. I didn&rsquo;t <em>want</em> to make simpleaichat: I&rsquo;d rather spend my time creating more cool projects with AI, and it&rsquo;s a shame I could not have done that with LangChain. I know someone will say &ldquo;why not submit a pull request to the LangChain repo since it&rsquo;s open source instead of complaining about it?&rdquo;, but most of my complaints are fundamental issues with the LangChain library that can&rsquo;t be changed without breaking everything for its existing users. The only real fix is to burn it all down and start fresh, which is why my &ldquo;create a new Python library for interfacing with AI&rdquo; solution is also the most pragmatic one.</p>
<p>I&rsquo;ve gotten many messages asking me &ldquo;what should I learn to get started with the ChatGPT API&rdquo; and I&rsquo;m concerned that they&rsquo;ll go to LangChain first because of the hype. If machine learning engineers who do have backgrounds in the technology stack have difficulty using LangChain due to its needless complexity, any beginner is going to drown.</p>
<p>No one wants to be that asshole who criticizes free and open source software operating in good faith like LangChain, but I&rsquo;ll take the burden. To be clear, I have nothing against Harrison Chase or the other maintainers of LangChain (who encourage feedback!). However, LangChain&rsquo;s popularity has warped the AI startup ecosystem around LangChain itself and the hope of OMG <a href="https://en.wikipedia.org/wiki/Artificial_general_intelligence">AGI</a> I MADE SKYNET, which is why I am compelled to be honest with my misgivings about it.</p>
<p>Wars over software complexity, and over software that stays popular despite that complexity, are an eternal recurrence. In the 2010s, it was with <a href="https://react.dev">React</a>; in 2023, it&rsquo;s with ReAct.</p>
<hr>
<p><em>Jupyter Notebooks for the <a href="https://github.com/minimaxir/langchain-problems/blob/main/openai_rewrite.ipynb">simple implementations of LangChain examples</a> and the <a href="https://github.com/minimaxir/langchain-problems/blob/main/langchain_problems.ipynb">LangChain failure demo</a> are available in <a href="https://github.com/minimaxir/langchain-problems/tree/main">this GitHub repository</a>.</em></p>
]]></content:encoded>
    </item>
    <item>
      <title>How to Create a Blog Post Title Optimizer with GPT-3 and Hacker News Data</title>
      <link>https://minimaxir.com/2022/08/gpt3-blog-title-optimizer/</link>
      <pubDate>Mon, 15 Aug 2022 08:30:00 -0700</pubDate>
      <guid>https://minimaxir.com/2022/08/gpt3-blog-title-optimizer/</guid>
      <description>GPT-3 says the title for this post is very bad.</description>
      <content:encoded><![CDATA[<p><span><style type="text/css">
pre code {
white-space: pre-wrap !important;
}
</style></span></p>
<p>I am objectively terrible at writing attractive titles for my blog posts. Which is a problem, as nowadays it&rsquo;s a commonly understood truth that a good headline can be the sole factor in whether a blog post goes viral or gets completely ignored, especially in the data science/machine learning fields I typically write about.</p>
<p>So, why not use said data science/machine learning to create an optimized title for me?</p>
<p>Many know <a href="https://openai.com/api/">GPT-3</a> as a tool for robust text generation. But a newer, lesser-discussed feature that <a href="https://openai.com">OpenAI</a> offers is finetuning GPT-3 on data you provide. If I provide GPT-3 with a large dataset of good titles, can I use that to tell me if one of my blog post titles is good? Let&rsquo;s give it a try.</p>
<h2 id="getting-the-good-blog-post-data-from-hacker-news">Getting The Good Blog Post Data from Hacker News</h2>
<p><em>All code and tools used in this blog post are available <a href="https://github.com/minimaxir/gpt3-blog-title-optimizer">open-source on GitHub</a>.</em></p>
<p>The AI classifier I will create will be a <a href="https://en.wikipedia.org/wiki/Binary_classification">binary classifier</a>, which returns the probability that an input blog post title is good, and from that I can provide alternate blog post titles and see roughly which is best from those probabilities.</p>
<p>In order to finetune GPT-3 for this use case, I need to obtain a decently large amount of post titles with <code>good</code> and <code>bad</code> labels. For this experiment, I&rsquo;ll use submission data from <a href="https://news.ycombinator.com">Hacker News</a>.</p>
<figure>

    <img loading="lazy" srcset="/2022/08/gpt3-blog-title-optimizer/hn_front_page_hu_9925f5c07495cf67.webp 320w,/2022/08/gpt3-blog-title-optimizer/hn_front_page_hu_becc7f094ce1feae.webp 768w,/2022/08/gpt3-blog-title-optimizer/hn_front_page_hu_c2d62ab7441c47dc.webp 1024w,/2022/08/gpt3-blog-title-optimizer/hn_front_page.png 1312w" src="hn_front_page.png"
         alt="Hacker News frontpage on August 14th, 2022."/> <figcaption>
            <p>Hacker News frontpage on August 14th, 2022.</p>
        </figcaption>
</figure>

<p>Hacker News data is good for a few reasons: each submission has community validation by a large number of people, submission titles cover a wide variety of idiosyncratic styles, and most of all, it&rsquo;s easy to get Hacker News submission data in bulk <a href="https://console.cloud.google.com/marketplace/details/y-combinator/hacker-news">from BigQuery</a>. For example, if I wanted to get all submissions between August 2020 and August 2022 with a score of at least 10 (the rough minimum to get on the front page and to filter out some spam), plus some light filters to remove things that are definitely not blog posts or articles (such as <a href="https://news.ycombinator.com/show">Show HNs</a> and social media), I&rsquo;d write a SQL query like this:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-sql" data-lang="sql"><span class="line"><span class="cl"><span class="k">SELECT</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="n">title</span><span class="p">,</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="n">score</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="k">FROM</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="o">`</span><span class="n">bigquery</span><span class="o">-</span><span class="k">public</span><span class="o">-</span><span class="k">data</span><span class="p">.</span><span class="n">hacker_news</span><span class="p">.</span><span class="k">full</span><span class="o">`</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="k">WHERE</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="k">type</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="s2">&#34;story&#34;</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="k">AND</span><span class="w"> </span><span class="n">score</span><span class="w"> </span><span class="o">&gt;=</span><span class="w"> </span><span class="mi">10</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="k">AND</span><span class="w"> </span><span class="n">url</span><span class="w"> </span><span class="k">IS</span><span class="w"> </span><span class="k">NOT</span><span class="w"> </span><span class="k">NULL</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="k">AND</span><span class="w"> </span><span class="k">timestamp</span><span class="w"> </span><span class="k">BETWEEN</span><span class="w"> </span><span class="s2">&#34;2020-08-01&#34;</span><span class="w"> </span><span class="k">AND</span><span class="w"> </span><span class="s2">&#34;2022-08-01&#34;</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="k">AND</span><span class="w"> </span><span class="k">NOT</span><span class="w"> </span><span class="n">REGEXP_CONTAINS</span><span class="p">(</span><span class="n">title</span><span class="p">,</span><span class="w"> </span><span class="n">r</span><span class="s2">&#34;^Show HN&#34;</span><span class="p">)</span><span class="w">
</span></span></span><span class="line"><span class="cl"><span class="w">  </span><span class="k">AND</span><span class="w"> </span><span class="k">NOT</span><span class="w"> </span><span class="n">REGEXP_CONTAINS</span><span class="p">(</span><span class="n">url</span><span class="p">,</span><span class="w"> </span><span class="n">r</span><span class="s2">&#34;(?:github|youtube|twitter)\.com&#34;</span><span class="p">)</span><span class="w">
</span></span></span></code></pre></div><p>This query returns roughly 90k submission titles total. For <code>good</code> titles, let&rsquo;s say we consider posts with at least 100 points as &ldquo;good&rdquo;, because it&rsquo;s a nice number, which is sometimes all that&rsquo;s necessary in the world of data science. There are about 27k posts with at least 100 points in that subset, which is more than sufficient. The harder part is selecting the <code>bad</code> titles: since there are 63k titles with fewer than 100 points, the dataset as-is is imbalanced ~1:3 and will lead to flawed training results.</p>
<p>There are two solutions: either repeat the <code>good</code> posts to roughly equal the number of <code>bad</code> posts, or take a subset of <code>bad</code> posts to roughly equal the number of <code>good</code> posts. We&rsquo;ll do the latter, since the sample size of <code>good</code> posts is large enough. Most people would download all 90k rows into something like Python to handle that sampling, but with SQL shenanigans you can do it entirely in BigQuery. (The annotated query <a href="https://github.com/minimaxir/gpt3-blog-title-optimizer/blob/main/gpt3_input_data.sql">is here</a>; it&rsquo;s out of scope for this post, but may be interesting for data science hiring managers who want to annoy candidates in screening interviews.)</p>
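<p>If you did do the sampling client-side, it&rsquo;s only a few lines of pandas; a sketch, assuming a DataFrame <code>df</code> with the <code>title</code> and <code>score</code> columns from the query above:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import pandas as pd

# Label the posts, then downsample the majority "bad" class to match the "good" class.
df["label"] = df["score"].apply(lambda s: "good" if s &gt;= 100 else "bad")
good = df[df["label"] == "good"]
bad = df[df["label"] == "bad"].sample(n=len(good), random_state=42)
balanced = pd.concat([good, bad]).sample(frac=1, random_state=42)  # shuffle
</code></pre></div>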
<p>This results in a ~55k title dataset: 27k <code>good</code>, 27k <code>bad</code>, perfectly balanced, as all datasets should be.</p>
<p>OpenAI&rsquo;s <a href="https://beta.openai.com/docs/guides/fine-tuning">finetuning API</a> takes in a JSONL file where each line is a JSON object with two fields: <code>prompt</code> and <code>completion</code> (no, I am not sure why it can&rsquo;t just be a CSV). In this case, the <code>prompt</code> is the title, prepended with <code>Title: </code> and with a <code> -&gt;</code> suffix per their documentation suggestions to &ldquo;align&rdquo; it better to GPT-3, and the <code>completion</code> is the good/bad label, prepended with a space because GPT-3 is weird like that. An example of the final dataset:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="p">{</span><span class="nt">&#34;prompt&#34;</span><span class="p">:</span><span class="s2">&#34;Title: How to slightly improve your life without trying -&gt;&#34;</span><span class="p">,</span><span class="nt">&#34;completion&#34;</span><span class="p">:</span><span class="s2">&#34; bad&#34;</span><span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span><span class="nt">&#34;prompt&#34;</span><span class="p">:</span><span class="s2">&#34;Title: SixtyFPS Becomes Slint -&gt;&#34;</span><span class="p">,</span><span class="nt">&#34;completion&#34;</span><span class="p">:</span><span class="s2">&#34; bad&#34;</span><span class="p">}</span>
</span></span><span class="line"><span class="cl"><span class="p">{</span><span class="nt">&#34;prompt&#34;</span><span class="p">:</span><span class="s2">&#34;Title: Family estrangement: Why adults are cutting off their parents -&gt;&#34;</span><span class="p">,</span><span class="nt">&#34;completion&#34;</span><span class="p">:</span><span class="s2">&#34; bad&#34;</span><span class="p">}</span>
</span></span></code></pre></div><p>OpenAI&rsquo;s CLI cleans the input data and can extract a validation set from it, which you should always do. Fortunately, BigQuery now offers JSONL export, so downloading the resulting dataset requires no further preprocessing. Once that&rsquo;s done, the CLI lets you finetune, with special options for binary classification. (The exact CLI command I used is <a href="https://github.com/minimaxir/gpt3-blog-title-optimizer/blob/main/finetune_command.txt">here</a>.)</p>
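<p>The general shape of that CLI flow, circa 2022 (a sketch: the file names here are hypothetical, and the exact flags may have changed since; the real command I used is linked above):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-txt" data-lang="txt">openai tools fine_tunes.prepare_data -f titles.jsonl

openai api fine_tunes.create \
  -t titles_prepared_train.jsonl \
  -v titles_prepared_valid.jsonl \
  -m babbage \
  --compute_classification_metrics \
  --classification_positive_class " good"
</code></pre></div>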
<p>Another understated aspect of GPT-3 is that there are weaker models that are faster and much cheaper than the default <code>davinci</code> model, which is what people generally mean when they say &ldquo;GPT-3&rdquo;. For text generation, the weaker models tend to have less coherent outputs, but for a simplified use case like binary classification they are more than sufficient. I&rsquo;ll use the <code>babbage</code> model, the second weakest.</p>
<p>The <a href="https://github.com/minimaxir/gpt3-blog-title-optimizer/blob/main/finetune_results.csv">final results</a> of the finetuning are about 63% accuracy on both the training and validation sets: not too much better than the default 50% accuracy of a balanced dataset for a binary classification problem, but given the problem difficulty it&rsquo;s better than <a href="https://minimaxir.com/2018/09/modeling-link-aggregators/">most approaches I&rsquo;ve done</a> for Hacker News data.</p>
<p>Once the finetuning is complete, you can query it, and ask it to return the probability of the returned token. Let&rsquo;s pass in the title for my last blog post: <strong><a href="https://minimaxir.com/2022/07/food-photography-ai/">Absurd AI-Generated Professional Food Photography with DALL-E 2</a></strong></p>
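<p>The query itself, sketched with the era&rsquo;s Completion API (the finetuned model name below is a hypothetical placeholder):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import openai

response = openai.Completion.create(
    model="babbage:ft-personal-2022-08-14",  # hypothetical finetune name
    prompt="Title: Absurd AI-Generated Professional Food Photography with DALL-E 2 -&gt;",
    max_tokens=1,
    temperature=0,
    logprobs=2,  # return log-probabilities for the top tokens
)
</code></pre></div>
<p>The relevant slice of the response:</p>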
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-json" data-lang="json"><span class="line"><span class="cl"><span class="s2">&#34;top_logprobs&#34;</span><span class="err">:</span> <span class="p">[</span>
</span></span><span class="line"><span class="cl">  <span class="p">{</span>
</span></span><span class="line"><span class="cl">    <span class="nt">&#34; bad&#34;</span><span class="p">:</span> <span class="mf">-0.34654787</span>
</span></span><span class="line"><span class="cl">  <span class="p">}</span>
</span></span></code></pre></div><p>Well, that&rsquo;s not promising.</p>
<p>For some <em>really</em> weird reason, the API returns a log-probability instead of the actual probability that you&rsquo;d want, so taking the <a href="https://en.wikipedia.org/wiki/Exponential_function">exponent</a> of that value results in a 70.7% probability it&rsquo;s bad, which means there&rsquo;s a 29.3% chance it&rsquo;s good.</p>
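<p>The arithmetic, for the skeptical:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import math

p_bad = math.exp(-0.34654787)  # ≈ 0.707
p_good = 1 - p_bad             # ≈ 0.293
</code></pre></div>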
<p>And <em>that</em> is why I need a title optimizer.</p>
<h2 id="using-instructgpt-to-create-alternate-titles">Using InstructGPT To Create Alternate Titles</h2>
<p>Since we now have a tool to determine the quality of blog post titles, how do we generate alternate titles that maintain the same meaning? I could think of tweaks to titles myself, but that takes <em>effort</em> and I am lazy. What if GPT-3 could create the candidate titles for me? It turns out, GPT-3&rsquo;s latest Instruct model can.</p>
<p>InstructGPT, <a href="https://openai.com/blog/instruction-following/">released in January</a> without much fanfare, is a version of <code>davinci</code> that OpenAI finetuned themselves to better respond to instructions. It worked so well that it&rsquo;s now the default GPT-3 model (noted as <code>text-davinci-002</code> in the backend UI).</p>
<p>InstructGPT is surprisingly robust with the right prompt engineering. You can tell it to create detailed product descriptions of <a href="https://twitter.com/minimaxir/status/1551609670237708288">nonexistent video games</a>, or write <a href="https://twitter.com/minimaxir/status/1536824548376465409">4chan-style greentexts</a> for any domain which maintain both the style and twist endings of the format.</p>
<figure>

    <img loading="lazy" srcset="/2022/08/gpt3-blog-title-optimizer/greentext_hu_56587d1b774f4459.webp 320w,/2022/08/gpt3-blog-title-optimizer/greentext.png 684w" src="greentext.png"
         alt="via OpenAI&rsquo;s GPT-3 Playground; all nonhighlighted text is the prompt."/> <figcaption>
            <p>via OpenAI&rsquo;s GPT-3 Playground; all nonhighlighted text is the prompt.</p>
        </figcaption>
</figure>

<p>After a bit of testing, the prompt I found worked best for this use case was:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-txt" data-lang="txt"><span class="line"><span class="cl">Rewrite the following blog post title into six different titles but optimized for social media virality: &lt;FILL IN TITLE&gt;
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">-
</span></span></code></pre></div><p>It&rsquo;s verbose, but that&rsquo;s prompt engineering for you. The <code>-</code> at the end informs GPT-3 that the output should be a list with dash-bullets, which will make it easier to programmatically split the final output into distinct titles.</p>
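<p>Splitting that dash-bulleted output back into clean titles is then trivial; a sketch:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">def parse_titles(generated_text):
    # Each title arrives as a "- " bullet: strip the bullets and whitespace.
    return [line.lstrip("- ").strip() for line in generated_text.split("\n") if line.strip()]
</code></pre></div>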
<p>You can test it on the <a href="https://beta.openai.com/playground">GPT-3 Playground</a>; if the <code>temperature</code> parameter is <code>0</code>, then the output will be deterministic.</p>
<p>Again putting in my last blog post <strong><a href="https://minimaxir.com/2022/07/food-photography-ai/">Absurd AI-Generated Professional Food Photography with DALL-E 2</a></strong> into InstructGPT:</p>
<figure>

    <img loading="lazy" srcset="/2022/08/gpt3-blog-title-optimizer/gpt3_demo_hu_27d5e6dccb947f02.webp 320w,/2022/08/gpt3-blog-title-optimizer/gpt3_demo_hu_88e1b7cc78964c46.webp 768w,/2022/08/gpt3-blog-title-optimizer/gpt3_demo_hu_2e274ce89a10aaac.webp 1024w,/2022/08/gpt3-blog-title-optimizer/gpt3_demo.png 1478w" src="gpt3_demo.png"
         alt="via OpenAI&rsquo;s GPT-3 Playground; all nonhighlighted text is the prompt."/> <figcaption>
            <p>via OpenAI&rsquo;s GPT-3 Playground; all nonhighlighted text is the prompt.</p>
        </figcaption>
</figure>

<p>All six of those titles are definitely an improvement, and all the text in green is what the programmatic API returns. Notably, despite the terseness of the input title and recency of DALL-E 2, InstructGPT is able to infer that the AI <em>creates</em> something and work from that, which is impressive.</p>
<h2 id="put-the-title-optimizer-into-action">Put The Title Optimizer Into Action!</h2>
<p><em>A walkthrough of the code used to interact with the GPT-3 API and make the optimizer is available <a href="https://github.com/minimaxir/gpt3-blog-title-optimizer/blob/main/gpt3_title_optimizer_walkthrough.ipynb">in this Jupyter Notebook</a>, and the final demos are available <a href="https://github.com/minimaxir/gpt3-blog-title-optimizer/blob/main/gpt3_title_optimizer_demo.ipynb">in this Notebook</a>.</em></p>
<p>Now that we have the two models ready, the workflow is simple (a rough sketch follows the list):</p>
<ul>
<li>Choose the title of a technical blog post I want to optimize.</li>
<li>Ping InstructGPT to get up to 6 alternate titles.</li>
<li>Extract/clean up the generated titles (i.e., split and strip whitespace).</li>
<li>For each of those alternate titles, ping the finetuned Hacker News GPT-3 for the probability that it is a <code>good</code> title.</li>
<li>In a pretty table, sort the titles by probability, descending.</li>
</ul>
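<p>Glued together, the whole workflow fits in two short functions; a rough sketch under the same assumptions as above (hypothetical finetune name, the old Completion API):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import math

import openai

def title_quality(title):
    # Probability the finetuned classifier assigns to the " good" label.
    response = openai.Completion.create(
        model="babbage:ft-personal-2022-08-14",  # hypothetical finetune name
        prompt=f"Title: {title} -&gt;",
        max_tokens=1, temperature=0, logprobs=2,
    )
    top_logprobs = response["choices"][0]["logprobs"]["top_logprobs"][0]
    if " good" in top_logprobs:
        return math.exp(top_logprobs[" good"])
    return 1 - math.exp(top_logprobs[" bad"])

def optimize_title(title, temperature=0.7):
    # Generate six candidate titles, score each one, and sort best-first.
    prompt = ("Rewrite the following blog post title into six different titles "
              f"but optimized for social media virality: {title}\n\n-")
    response = openai.Completion.create(
        model="text-davinci-002", prompt=prompt,
        max_tokens=256, temperature=temperature,
    )
    candidates = [line.lstrip("- ").strip()
                  for line in response["choices"][0]["text"].split("\n")
                  if line.strip()]
    return sorted(((t, title_quality(t)) for t in candidates),
                  key=lambda x: x[1], reverse=True)
</code></pre></div>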
<p>Because the model can&rsquo;t be widely distributed without review due to OpenAI rules, I decided to put the &ldquo;UI&rdquo; for this into a personal Jupyter Notebook.</p>
<p>Let&rsquo;s experiment! We know the title of <strong><a href="https://minimaxir.com/2022/07/food-photography-ai/">Absurd AI-Generated Professional Food Photography with DALL-E 2</a></strong> is bad and the alternatives are interesting, but how good are the alternatives?</p>
<figure>

    <img loading="lazy" srcset="/2022/08/gpt3-blog-title-optimizer/dalle_hu_699e00a90cc5bfc7.webp 320w,/2022/08/gpt3-blog-title-optimizer/dalle_hu_11191d9ffc307102.webp 768w,/2022/08/gpt3-blog-title-optimizer/dalle_hu_e93a795c94810124.webp 1024w,/2022/08/gpt3-blog-title-optimizer/dalle.png 1324w" src="dalle.png"
         alt="via GPT-3 Title Optimizer"/> <figcaption>
            <p>via GPT-3 Title Optimizer</p>
        </figcaption>
</figure>

<p>Most of the alternates are <em>much</em> better, with the predicted probabilities of being a good post going above 50%. (I probably should change the title retroactively, but I will live with my SEO dishonor.)</p>
<p>The original title for this post, in my boring no-one-will-ever-click-this style, was <strong>Creating a Blog Post Title Optimizer by Finetuning GPT-3 on Hacker News</strong>. Let&rsquo;s plop it into the optimizer:</p>
<figure>

    <img loading="lazy" srcset="/2022/08/gpt3-blog-title-optimizer/temp0_hu_c45a69584ea12fd7.webp 320w,/2022/08/gpt3-blog-title-optimizer/temp0_hu_58abecd2066b41fe.webp 768w,/2022/08/gpt3-blog-title-optimizer/temp0_hu_c0a2a3331086ec1.webp 1024w,/2022/08/gpt3-blog-title-optimizer/temp0.png 1266w" src="temp0.png"
         alt="via GPT-3 Title Optimizer, temperature=0"/> <figcaption>
            <p>via GPT-3 Title Optimizer, <code>temperature=0</code></p>
        </figcaption>
</figure>

<p>So yes, the optimizer says the original title is very bad. But in this case, the variants are clickbaity and probably wouldn&rsquo;t do very well on Hacker News.</p>
<p>Fortunately, you can rerun the generation and get different variants if <code>temperature</code> is nonzero.</p>
<figure>

    <img loading="lazy" srcset="/2022/08/gpt3-blog-title-optimizer/temp_0_7_hu_593f048f9b7f6c5f.webp 320w,/2022/08/gpt3-blog-title-optimizer/temp_0_7_hu_585270e1edaa9949.webp 768w,/2022/08/gpt3-blog-title-optimizer/temp_0_7_hu_4911b1029897a8ad.webp 1024w,/2022/08/gpt3-blog-title-optimizer/temp_0_7.png 1276w" src="temp_0_7.png"
         alt="via GPT-3 Title Optimizer, temperature=0.7"/> <figcaption>
            <p>via GPT-3 Title Optimizer, <code>temperature=0.7</code></p>
        </figcaption>
</figure>

<figure>

    <img loading="lazy" srcset="/2022/08/gpt3-blog-title-optimizer/temp_1_0_hu_c796ae615316a9f7.webp 320w,/2022/08/gpt3-blog-title-optimizer/temp_1_0_hu_73c96d490a895ee1.webp 768w,/2022/08/gpt3-blog-title-optimizer/temp_1_0_hu_11a99b16ef025c5.webp 1024w,/2022/08/gpt3-blog-title-optimizer/temp_1_0.png 1246w" src="temp_1_0.png"
         alt="via GPT-3 Title Optimizer, temperature=1.0"/> <figcaption>
            <p>via GPT-3 Title Optimizer, <code>temperature=1.0</code></p>
        </figcaption>
</figure>

<p>Definitely more variety. I like &ldquo;How to Create a Blog Post Title Optimizer with GPT-3&rdquo; as it maintains the same spirit even if it&rsquo;s not the most optimal, although for disclosure reasons, I do want to include Hacker News somewhere in the title. Therefore, I can tweak the input to &ldquo;How to Create a Blog Post Title Optimizer with GPT-3 and Hacker News Data&rdquo;, feed it back to the optimizer, and maybe get an iterative improvement.</p>
<figure>

    <img loading="lazy" srcset="/2022/08/gpt3-blog-title-optimizer/final_variant_hu_21b52227f7a46793.webp 320w,/2022/08/gpt3-blog-title-optimizer/final_variant_hu_ac4d8c781b4ab4be.webp 768w,/2022/08/gpt3-blog-title-optimizer/final_variant_hu_5671c5b32b4bcdb9.webp 1024w,/2022/08/gpt3-blog-title-optimizer/final_variant.png 1328w" src="final_variant.png"
         alt="via GPT-3 Title Optimizer"/> <figcaption>
            <p>via GPT-3 Title Optimizer</p>
        </figcaption>
</figure>

<p>The probability went down significantly with the change, and none of the variants are much better. Oh well.</p>
<p>Here&rsquo;s the results of running the optimizer for some of my older blog posts:</p>
<figure>

    <img loading="lazy" srcset="/2022/08/gpt3-blog-title-optimizer/moocs_hu_76ce8e5ea2a2036b.webp 320w,/2022/08/gpt3-blog-title-optimizer/moocs_hu_b8f0726fdef3a57c.webp 768w,/2022/08/gpt3-blog-title-optimizer/moocs_hu_ff658cab6ad7c272.webp 1024w,/2022/08/gpt3-blog-title-optimizer/moocs.png 1390w" src="moocs.png"
         alt="The results for this post are indeed better; I&rsquo;d definitely click the top one although it&rsquo;s misleading."/> <figcaption>
            <p>The results for <a href="https://minimaxir.com/2018/10/data-science-protips/">this post</a> are indeed better; I&rsquo;d definitely click the top one although it&rsquo;s misleading.</p>
        </figcaption>
</figure>

<figure>

    <img loading="lazy" srcset="/2022/08/gpt3-blog-title-optimizer/ncaa_hu_f9c3d807b6e6523e.webp 320w,/2022/08/gpt3-blog-title-optimizer/ncaa_hu_8755fbee1b6a25d4.webp 768w,/2022/08/gpt3-blog-title-optimizer/ncaa_hu_37c95655b09516ef.webp 1024w,/2022/08/gpt3-blog-title-optimizer/ncaa.png 1372w" src="ncaa.png"
         alt="The results for this post are much better, although this is one case where the original title is actually good."/> <figcaption>
            <p>The results for <a href="https://minimaxir.com/2018/03/basketball-shots/">this post</a> are much better, although this is one case where the original title is actually good.</p>
        </figcaption>
</figure>

<figure>

    <img loading="lazy" srcset="/2022/08/gpt3-blog-title-optimizer/pc_game_hu_78a640ce719c4399.webp 320w,/2022/08/gpt3-blog-title-optimizer/pc_game_hu_fcd70b2118389481.webp 768w,/2022/08/gpt3-blog-title-optimizer/pc_game_hu_d706e98685168754.webp 1024w,/2022/08/gpt3-blog-title-optimizer/pc_game.png 1246w" src="pc_game.png"
         alt="The results for this post are a balance between better and not-technically-misleading clickbait."/> <figcaption>
            <p>The results for <a href="https://minimaxir.com/2013/06/working-as-intended/">this post</a> are a balance between better and not-<em>technically</em>-misleading clickbait.</p>
        </figcaption>
</figure>

<p>Costwise, the entire pipeline is relatively inexpensive. Overall, it&rsquo;s about $0.02 per run: too expensive to offer unrestricted access to the internet at large, but a very high return on investment if it successfully results in a catchy headline, even if it takes multiple tries. The most expensive part was the finetuning itself, which cost $2, but that is a one-time cost.</p>
<p>Some might ask &ldquo;why finetune GPT-3 when you can finetune an open-source large language model such as <a href="https://en.wikipedia.org/wiki/BERT_%28language_model%29">BERT</a> like every NLP project since 2018?&rdquo; In this case, GPT-3&rsquo;s advantage is that it was trained on the entire internet. GPT-3 is a master of idiosyncrasy, which is key when working with Hacker News data, and in theory it would give better results than the Wikipedia-trained BERT. The success of Hacker News posts also depends on a global context outside of the title itself, which is why finetuning an existing model trained on such context may be better than training a model solely on HN data.</p>
<p>Some are concerned about GPT-3 and AI tools such as these making writers redundant, but the results here prove otherwise: there will always have to be a human-in-the-loop.</p>
<hr>
<p>UPDATE: When I <a href="https://news.ycombinator.com/item?id=32471208">submitted this post</a> to Hacker News, it ended up getting <em>over 200 points</em>, defying the 20.8% probability!</p>
]]></content:encoded>
    </item>
    <item>
      <title>Things About Real-World Data Science Not Discussed In MOOCs and Thought Pieces</title>
      <link>https://minimaxir.com/2018/10/data-science-protips/</link>
      <pubDate>Mon, 22 Oct 2018 09:15:00 -0700</pubDate>
      <guid>https://minimaxir.com/2018/10/data-science-protips/</guid>
      <description>MOOCs and thought pieces overfit to a certain style of data science that is not robust to the vast uncertainties of the real world.</description>
      <content:encoded><![CDATA[<p><a href="https://en.wikipedia.org/wiki/Data_science">Data science</a> has been sweeping the tech world. With a large variety of powerful free open-sourced tools and now the computing power to utilize them to their full potential, data science is more accessible than ever and has become <a href="https://www.bloomberg.com/news/articles/2018-05-18/-sexiest-job-ignites-talent-wars-as-demand-for-data-geeks-soars">America&rsquo;s hottest job</a>. One problem: there&rsquo;s no consensus on <a href="https://hbr.org/2018/08/what-data-scientists-really-do-according-to-35-data-scientists">what data scientists <em>really</em> do</a> in a professional setting.</p>
<p>There has been a rise in <em>romantic</em> thought pieces lately (especially on <a href="https://medium.com">Medium</a>) about how data scientists are wizards and can solve any problem (with bonus points if it cites AI). If you follow publications like <a href="https://towardsdatascience.com">Towards Data Science</a>, you&rsquo;ll notice persistent tropes in the more code-oriented posts: Python is the king programming language for data science, use <a href="http://scikit-learn.org/stable/">scikit-learn</a>/<a href="https://xgboost.readthedocs.io/en/latest/">XGBoost</a> and logistic regression for predicting categorical variable(s), use <a href="https://pandas.pydata.org">pandas</a> for processing tabular data, use <a href="https://www.nltk.org">NLTK</a>/<a href="https://en.wikipedia.org/wiki/Word2vec">word2vec</a> for processing text data, use <a href="https://www.tensorflow.org">TensorFlow</a>/<a href="https://keras.io">Keras</a>/convolutional neural networks for processing image data, use <a href="https://en.wikipedia.org/wiki/K-means_clustering"><em>k</em>-means</a> for clustering data, split the processed dataset into training and test datasets for model training, tweak hyperparameters/model features <a href="https://xkcd.com/1838/">until results on the test dataset are good</a>, etc.</p>
<figure>

    <img loading="lazy" srcset="/2018/10/data-science-protips/thought_hu_a119caa2480267cc.webp 320w,/2018/10/data-science-protips/thought.png 397w" src="thought.png"/> 
</figure>

<p>These tropes aren&rsquo;t inappropriate or misleading, but the analysis often doesn&rsquo;t quantify the insight/value of the results. Modeling is just one small part (and often the <em>easiest</em> part) of a very complex system.</p>
<p>Data-oriented MOOCs (<a href="https://en.wikipedia.org/wiki/Massive_open_online_course">Massive Open Online Courses</a>) like Andrew Ng&rsquo;s <a href="https://www.coursera.org/learn/machine-learning">Coursera course on Machine Learning</a> and <a href="http://course.fast.ai">fast.ai&rsquo;s course on Deep Learning</a> are good academic introductions to the theory and terminology behind data science and other related fields. Although MOOCs have many practice problems for prospective data scientists to solve, they don&rsquo;t make you an expert in the field capable of handling messier real-world problems, nor do they claim to do so.</p>
<p>Modern data science isn&rsquo;t about burying your head in a <a href="http://jupyter.org">Jupyter Notebook</a> and staring at the screen watching training loss numbers trickle down (although it&rsquo;s definitely fun!). There&rsquo;s a lot more to it, some of which I&rsquo;ve learned firsthand working as a Data Scientist at <a href="https://www.buzzfeed.com">BuzzFeed</a> for over a year. To borrow a statistical term, MOOCs and thought pieces <em>overfit</em> to a certain style of data science that is not robust to the vast uncertainties of the real world.</p>
<h2 id="the-costbenefit-tradeoffs-of-data-science">The Cost/Benefit Tradeoffs of Data Science</h2>
<p>Data science often follows the <a href="https://en.wikipedia.org/wiki/Pareto_principle">Pareto principle</a>: 80% of the results take 20% of the effort. Thought pieces demonstrate that you can just toss data indiscriminately into scikit-learn or a deep learning framework and get neat-looking results. The value of a data scientist, however, is knowing when and <em>if</em> to pursue further development on a model.</p>
<p><a href="https://www.kaggle.com/competitions">Kaggle competitions</a> are a popular and often-recommended way to get exposure to real-world data science problems. Many teams of statisticians compete to create the best model for a given dataset (where &ldquo;best&rdquo; usually means minimizing the predictive loss/error of the model), with prizes for the highest-performing models. Kaggle also encourages clever modeling techniques such as <a href="http://scikit-learn.org/stable/modules/grid_search.html">grid search</a> of thousands of model hyperparameter combinations and ensembling disparate models to create a megamodel which results in <em>slightly</em> better predictive performance, but just might give the edge to win.</p>
<p>However, there are a few important differences between modeling in a Kaggle competition and modeling in a data science team. Kaggle competitions last for <em>weeks</em>, whereas a professional data scientist may need to spend that time on other things. Ensembling gigantic machine learning models makes predictions very slow and the models themselves very large, both of which may cause difficulty deploying them into production (e.g. the <a href="https://www.wired.com/2012/04/netflix-prize-costs/">Netflix Prize</a> movie recommendation models famously &ldquo;did not seem to justify the engineering effort needed to bring them into a production environment&rdquo;). And most importantly, there may not be a significant <em>practical</em> performance difference between a 1st place Kaggle model that takes days/weeks to optimize and a simple scikit-learn/XGBoost baseline that can be built in a few hours.</p>
<p>Counterintuitively, it may be better to trade performance for speed/memory with a weaker-but-faster model; in business cases, speed and scalability are important implementation constraints. But even with scikit-learn, the model is still a <a href="https://en.wikipedia.org/wiki/Black_box">black box</a>, giving the data scientist little idea of how the model makes its decisions. One final option is to go back to basics altogether with a &ldquo;boring&rdquo; linear/logistic regression model, where the predictive performance may be even weaker and the model <a href="http://statisticsbyjim.com/regression/ols-linear-regression-assumptions/">must follow several statistical assumptions</a>, but the model feature coefficients and statistical significance <a href="http://blog.minitab.com/blog/adventures-in-statistics-2/how-to-interpret-regression-analysis-results-p-values-and-coefficients">are easily interpretable</a>, letting you explain the importance of each input feature (if any) and make actionable, informed decisions for the business. Being a data scientist requires making educated judgments about these tradeoffs.</p>
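<p>As a sketch of that last option (the data is synthetic and the features are made up; statsmodels is one of several libraries that would work), a &ldquo;boring&rdquo; logistic regression hands you coefficients and <em>p</em>-values for free:</p>
<pre tabindex="0"><code class="language-python" data-lang="python">import numpy as np
import statsmodels.api as sm

# toy problem: does a user convert, given their age and number of site visits?
rng = np.random.RandomState(42)
age = rng.normal(35, 10, 500)
visits = rng.poisson(3, 500)
converted = (0.4 * visits + rng.normal(0, 1, 500) &gt; 1.5).astype(int)

X = sm.add_constant(np.column_stack([age, visits]))
model = sm.Logit(converted, X).fit()
print(model.summary())  # coefficients, standard errors, and p-values
</code></pre>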
<h2 id="data-scientists-still-use-business-intelligence-tools">Data Scientists Still Use Business Intelligence Tools</h2>
<p>A hobbyist data scientist without a budget may opt to build their own workflows and data pipelines using free tools. However, professional data scientists have a finite amount of free time (as do all engineers), so there&rsquo;s a massive opportunity cost in reinventing the wheel unnecessarily. Enterprise BI tools such as <a href="https://www.tableau.com">Tableau</a>, <a href="https://looker.com">Looker</a>, and <a href="https://modeanalytics.com">Mode Analytics</a> help retrieve and present data with easy-to-digest dashboards for anyone in the company. They&rsquo;re never cheap, but they&rsquo;re much cheaper to the company than having a data scientist spend valuable time developing and maintaining similar tooling.</p>
<p>If a stakeholder wants a data report ASAP, there&rsquo;s no problem falling back to using <a href="https://en.wikipedia.org/wiki/SQL">SQL</a> to query a data warehouse and output the results into an Excel spreadsheet (plus pretty data visualizations!) to send over quickly in an email. Part of being a data scientist is knowing which tools are most appropriate at which time.</p>
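<p>That fallback can be a few lines of pandas; in the sketch below, the connection string, table, and output filename are all hypothetical:</p>
<pre tabindex="0"><code class="language-python" data-lang="python">import pandas as pd
from sqlalchemy import create_engine

# hypothetical warehouse connection and query
engine = create_engine(&#39;postgresql://user:pass@warehouse:5432/analytics&#39;)
query = &#39;&#39;&#39;
    SELECT date, channel, SUM(pageviews) AS pageviews
    FROM traffic
    GROUP BY date, channel
&#39;&#39;&#39;

df = pd.read_sql(query, engine)
df.to_excel(&#39;traffic_report.xlsx&#39;, index=False)  # attach to an email and done
</code></pre>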
<p>Some might argue that using BI tools and SQL are not responsibilities for data scientists, but instead for Business Analysts or Data Analysts. That&rsquo;s a <a href="https://en.wikipedia.org/wiki/No_true_Scotsman">No True Scotsman</a> way of looking at it; there&rsquo;s a lot of overlap in data science with other analytical fields, and there&rsquo;s nothing wrong with that.</p>
<h2 id="data-scientists-are-software-engineers-too">Data Scientists Are Software Engineers Too</h2>
<p>Although MOOCs encourage <em>self</em>-study, data science is a collaborative process, and not just with other data scientists on a team, but with other software engineers in the company. Version control tools like <a href="https://git-scm.com">Git</a> are often used by data scientists to upload their portfolio projects publicly to <a href="https://github.com">GitHub</a>, but there are many other features important in a company-wide collaborative environment, such as branching a repository, making pull requests, and resolving merge conflicts. Beyond that are modern development QA practices, such as test environments, consistent code style, and code reviews. The full process varies strongly by company: Airbnb has a <a href="https://medium.com/airbnb-engineering/scaling-knowledge-at-airbnb-875d73eff091">good thought piece</a> about how they utilize their Knowledge Base for data science collaboration using Git.</p>
<p>One of the very hard and surprisingly underdiscussed aspects of data science is <a href="https://en.wikipedia.org/wiki/DevOps">DevOps</a>: how to actually get a statistical model into production. <a href="https://www.docker.com/resources/what-container">Docker containers</a>, for example, are a newer technology that&rsquo;s hard to learn, but they have many data science and DevOps benefits, mitigating Python dependency hell and ensuring a consistent environment for model deployment and execution. And once the model is in production, data scientists, data engineers, and dedicated DevOps personnel need to work together to figure out whether the model has the expected output, whether the model is performing with the expected speed/memory overhead, how often to retrain the model on fresh data (plus the scheduling/data pipelining necessary to do so), and how to efficiently route predictions out of the system to the user.</p>
<h2 id="data-science-cant-solve-everything">Data Science Can&rsquo;t Solve Everything</h2>
<p>Data science experiments (even those utilizing magical AI) are allowed to fail, and not just in the fail-to-reject-the-null-hypothesis sense. Thought pieces typically discuss successful projects, which leads to a survivorship bias. Even with massive amounts of input data, it&rsquo;s <em>likely</em> that a model will fail to converge and offer zero insight, or that an experiment will fail to offer statistically significant results (common with <a href="https://vwo.com/ab-testing/">A/B testing</a>).</p>
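<p>Checking an A/B test for significance is itself only a few lines, e.g. a two-proportion <em>z</em>-test from statsmodels with made-up conversion counts:</p>
<pre tabindex="0"><code class="language-python" data-lang="python">from statsmodels.stats.proportion import proportions_ztest

# hypothetical A/B test: 130/2000 conversions for A vs. 145/2000 for B
stat, p_value = proportions_ztest(count=[130, 145], nobs=[2000, 2000])
print(round(p_value, 3))  # ~0.35 here: not significant, a perfectly normal outcome
</code></pre>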
<p><span><blockquote class="twitter-tweet" data-lang="en"><p lang="en" dir="ltr">real world data science is an R<sup>2</sup> of 0.10 <a href="https://twitter.com/hashtag/GoogleNext18?src=hash&amp;ref_src=twsrc%5Etfw">#GoogleNext18</a> <a href="https://t.co/qNsno2dscR">pic.twitter.com/qNsno2dscR</a></p>— Max Woolf (@minimaxir) <a href="https://twitter.com/minimaxir/status/1021885939361042432?ref_src=twsrc%5Etfw">July 24, 2018</a></blockquote> <script async src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>
</span></p>
<p>The difficulty of real-world data science is recognizing if a given problem <em>can</em> be solved, how much of your valuable time to spend iterating to <em>maybe</em> solve it, how to report to stakeholders if it <em>can&rsquo;t</em> be solved, and what are the next steps if that&rsquo;s the case.</p>
<p>Don&rsquo;t <a href="https://www.buzzfeednews.com/article/stephaniemlee/brian-wansink-cornell-p-hacking"><em>p</em>-hack</a>!</p>
<h2 id="data-science-and-ethics">Data Science and Ethics</h2>
<p>During the rise of the &ldquo;data science/AI is magic!&rdquo; era, massive algorithmic and statistical failures suggest that data science might not always make the world a better place. Amazon built a resume-reading model which <a href="https://www.reuters.com/article/us-amazon-com-jobs-automation-insight/amazon-scraps-secret-ai-recruiting-tool-that-showed-bias-against-women-idUSKCN1MK08G">accidentally learned to be sexist</a>. Facebook overestimated <a href="https://www.theverge.com/2018/10/17/17989712/facebook-inaccurate-video-metrics-inflation-lawsuit">performance metrics on their videos</a>, causing complete business pivots for media organizations in vain, indirectly <a href="https://www.theatlantic.com/technology/archive/2018/10/facebook-driven-video-push-may-have-cost-483-journalists-their-jobs/573403/">leading to hundreds of layoffs</a>. YouTube&rsquo;s recommended video algorithms <a href="https://medium.com/@jamesbridle/something-is-wrong-on-the-internet-c39c471271d2">drove children towards shocking and disturbing content</a>. And these companies have some of the best data talent <em>in the entire world</em>.</p>
<p>The <em>qualitative</em> output of a model or data analysis is just as important as the quantitative performance, if not more so. Allowing dangerous model output to hit production and impact <em>millions</em> of consumers is a failure of QA at all levels. In fairness, these companies usually fix these issues, but only <em>after</em> journalists <a href="https://www.nytimes.com/2018/10/19/opinion/facebook-twitter-journalism-misinformation.html">point them out</a>. The problem with blindly chasing a performance metric (as on Kaggle) is that it ignores collateral, unexpected effects.</p>
<p><span><blockquote class="twitter-tweet" data-lang="en"><p lang="en" dir="ltr">Don’t be data-driven. Be data-informed. Metrics should never be in charge because they have no moral compass.</p>— Kim Goodwin (@kimgoodwin) <a href="https://twitter.com/kimgoodwin/status/1051849805280948224?ref_src=twsrc%5Etfw">October 15, 2018</a></blockquote> <script async src="https://platform.twitter.com/widgets.js" charset="utf-8"></script> </span></p>
<p>Maybe recommending shocking videos is what maximizes clickthrough rate or ad revenue per the models according to a business dashboard. Unfortunately, if the data justifies it and the business stakeholders encourage it, the company may <em>accept the consequences</em> of a flawed algorithm if those consequences don&rsquo;t outweigh the benefits. It&rsquo;s important for data scientists to be aware that they may be party to that.</p>
<h2 id="conclusion">Conclusion</h2>
<p>I realize the irony of using a data science thought piece to argue against data science thought pieces. In fairness, some Medium thought pieces do apply data science in genuinely <em>unique</em> ways or touch on very obscure-but-impactful aspects of frameworks, and I enjoy reading those. The field is still very broadly defined, and your experiences may differ from this post, especially if you&rsquo;re working for a more research-based institution. Unfortunately, I don’t have any new advice for <em>getting</em> a data science job, which is <a href="https://twitter.com/minimaxir/status/951117788835278848">still very difficult</a>.</p>
<p>The popular idea that being a data scientist is a 40-hours-a-week Kaggle competition is <strong>incorrect</strong>. There&rsquo;s a lot more to it that&rsquo;s not as sexy which, in my opinion, is the more interesting aspect of the data science field as a whole.</p>
]]></content:encoded>
    </item>
    <item>
      <title>How to Quickly Train a Text-Generating Neural Network for Free</title>
      <link>https://minimaxir.com/2018/05/text-neural-networks/</link>
      <pubDate>Fri, 18 May 2018 09:00:00 -0700</pubDate>
      <guid>https://minimaxir.com/2018/05/text-neural-networks/</guid>
      <description>Train your own text-generating neural network and generate text whenever you want with just a few clicks!</description>
<content:encoded><![CDATA[<p>One of the more interesting applications of the neural network revolution is text generation. Most popular approaches are based on Andrej Karpathy&rsquo;s <a href="https://github.com/karpathy/char-rnn">char-rnn architecture</a>/<a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness/">blog post</a>, which teaches a recurrent neural network to predict the next character in a sequence based on the previous <em>n</em> characters. As a result, a sufficiently trained network can theoretically reproduce its input source material, but since even properly-trained neural networks aren&rsquo;t <em>perfect</em>, the output can fall into a weird-but-good uncanny valley.</p>
<p>Many internet tutorials for text-generation neural networks simply copy an existing char-rnn implementation while changing the input dataset. It&rsquo;s one approach, but there&rsquo;s an opportunity for improvement with modern deep learning tooling. Thanks to frameworks like <a href="https://www.tensorflow.org">TensorFlow</a> and <a href="https://github.com/keras-team/keras">Keras</a>, I built <a href="https://github.com/minimaxir/textgenrnn">textgenrnn</a>, a <a href="https://pypi.org/project/textgenrnn/#description">Python package</a> which abstracts the process of creating and training such char-rnns to a <em>few lines of code</em>, with numerous model architecture and training improvements such as <a href="http://minimaxir.com/2017/04/char-embeddings/">character embeddings</a>, attention-weighted averaging, and a decaying learning rate.</p>
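<p>Those few lines look roughly like this (a minimal sketch; the filename and parameter values are placeholders):</p>
<pre tabindex="0"><code class="language-python" data-lang="python">from textgenrnn import textgenrnn

textgen = textgenrnn()  # initializes with the included pretrained model
textgen.train_from_file(&#39;input_texts.txt&#39;, num_epochs=5)
textgen.generate(3, temperature=0.5)  # print 3 generated texts
</code></pre>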
<p>A neat benefit of textgenrnn is that it can be easily used to train neural networks on a GPU very quickly, <em>for free</em> using <a href="https://colab.research.google.com/notebooks/welcome.ipynb">Google Colaboratory</a>. I&rsquo;ve <a href="https://drive.google.com/file/d/1mMKGnVxirJnqDViH7BDJxFqWrsXlPSoK/view?usp=sharing">created a notebook</a> which lets you train your own network and generate text whenever you want with just a few clicks!</p>
<h2 id="your-first-text-generating-neural-network">Your First Text-Generating Neural Network</h2>
<p>Colaboratory is a notebook environment similar to the <a href="http://jupyter.org">Jupyter Notebooks</a> used in other data science projects. However, Colaboratory notebooks are hosted on a short-term virtual machine, with 2 vCPUs, 13GB of memory, and a K80 GPU attached. For free. Normally, this configuration would <a href="https://cloud.google.com/compute/pricing">cost</a> $0.57/hr on Google Compute Engine; that sounds low, but it adds up when you need to train model(s) for hours to get good results.</p>
<p>First, I recommend copying the notebook to your own Drive so it&rsquo;ll always be there (and switch to using Google Chrome if you aren&rsquo;t already). The Colaboratory VM contains Python 3 and common Python packages for machine learning such as TensorFlow, but you can install more packages directly in the notebook. Like textgenrnn! Just run this cell by clicking into it and clicking the &ldquo;play&rdquo; button (or use Shift + Enter), and it&rsquo;ll take care of the rest:</p>
<figure>

    <img loading="lazy" srcset="/2018/05/text-neural-networks/pip_hu_2961e2ffcd246a46.webp 320w,/2018/05/text-neural-networks/pip_hu_897c5f1e48e49d.webp 768w,/2018/05/text-neural-networks/pip.png 936w" src="pip.png"/> 
</figure>

<p>When training a new model, textgenrnn allows you to specify the size and complexity of the neural network with a wide variety of parameters:</p>
<figure>

    <img loading="lazy" srcset="/2018/05/text-neural-networks/config_hu_f70243006edc1525.webp 320w,/2018/05/text-neural-networks/config.png 515w" src="config.png"/> 
</figure>

<p>Let&rsquo;s keep these default parameters for now, so run that cell to load them into memory. Run the next cell, which prompts you to upload a file. <em>Any text file should work</em>, even large text files! For this example, we&rsquo;ll use a 1.1MB text file of Shakespeare plays also <a href="https://github.com/karpathy/char-rnn/tree/master/data/tinyshakespeare">used in the char-rnn demos</a>.</p>
<figure>

    <img loading="lazy" srcset="/2018/05/text-neural-networks/upload_hu_82dc8aa0d091b8b8.webp 320w,/2018/05/text-neural-networks/upload_hu_b8978ae780ae5333.webp 768w,/2018/05/text-neural-networks/upload_hu_5121df32448d3012.webp 1024w,/2018/05/text-neural-networks/upload.png 1812w" src="upload.png"/> 
</figure>

<p>The next cell initializes an instance of textgenrnn and begins training a custom new text-generating neural network!</p>
<figure>

    <img loading="lazy" srcset="/2018/05/text-neural-networks/train_hu_afa90d7a652eb6c6.webp 320w,/2018/05/text-neural-networks/train_hu_876c5fbe08bb13ca.webp 768w,/2018/05/text-neural-networks/train_hu_cb2d12d9c68cd1a4.webp 1024w,/2018/05/text-neural-networks/train.png 1810w" src="train.png"/> 
</figure>

<p>textgenrnn automatically processes the input text into character sequences ready to train the network. After every 2 epochs (a full pass through the data), the network will generate sample text at different temperatures, which represent the &ldquo;creativity&rdquo; of the text (i.e. it allows the model to make increasingly suboptimal predictions, which can cause hilarity to ensue). I typically like generating text at a temperature of 0.5, but for very well-trained models, you can go up to 1.0.</p>
<p>The quick model training speed comes from the VM&rsquo;s GPU, which can perform the necessary mathematical operations much faster than a CPU. In the case of recurrent neural networks, Keras recently added a <a href="https://keras.io/layers/recurrent/#cudnnlstm">CuDNN implementation of RNNs</a> like LSTMs, which taps into GPU-native code and gains a <em>massive</em> speed boost (<a href="http://minimaxir.com/2017/11/benchmark-gpus/">about <em>7x as fast</em></a>) compared to previous implementations. In all, for this example dataset and model architecture, training on the GPU took 5-6 minutes per epoch, while on a modern CPU, training took <em>1 hour and 24 minutes</em> per epoch: a <strong>14x speedup</strong> on the GPU!</p>
<p>After training is complete, running the next cell will download three files: a <code>weights</code> file, a <code>vocabulary</code> file, and a <code>config</code> file that are all needed to regenerate your model elsewhere.</p>
<figure>

    <img loading="lazy" srcset="/2018/05/text-neural-networks/download_hu_e81503700cb98809.webp 320w,/2018/05/text-neural-networks/download_hu_e9891a5332f7cbcd.webp 768w,/2018/05/text-neural-networks/download_hu_cbcca1b791ab82ef.webp 1024w,/2018/05/text-neural-networks/download.png 1618w" src="download.png"/> 
</figure>

<p>For example, on your own personal computer. Just install textgenrnn + TensorFlow by inputting <code>pip3 install textgenrnn tensorflow</code> into a terminal, change to the directory where the downloaded files are located, run <code>python3</code>, and load the model using:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="kn">from</span> <span class="nn">textgenrnn</span> <span class="kn">import</span> <span class="n">textgenrnn</span>
</span></span><span class="line"><span class="cl"><span class="n">textgen</span> <span class="o">=</span> <span class="n">textgenrnn</span><span class="p">(</span><span class="n">weights_path</span><span class="o">=</span><span class="s1">&#39;colaboratory_weights.hdf5&#39;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                       <span class="n">vocab_path</span><span class="o">=</span><span class="s1">&#39;colaboratory_vocab.json&#39;</span><span class="p">,</span>
</span></span><span class="line"><span class="cl">                       <span class="n">config_path</span><span class="o">=</span><span class="s1">&#39;colaboratory_config.json&#39;</span><span class="p">)</span>
</span></span></code></pre></div><p>And that&rsquo;s that! No GPU necessary if you&rsquo;re just generating text. You can generate samples (like during training) using <code>textgen.generate_samples()</code>, generate a ton of samples at any temperature you like to a file using <code>textgen.generate_to_file()</code>, or incorporate a generated text into a Python script (e.g. a Twitter bot) using <code>textgen.generate(1, return_as_list=True)[0]</code> to store a text as a variable. You can view more of textgenrnn&rsquo;s functions and capabilities in <a href="https://github.com/minimaxir/textgenrnn/blob/master/docs/textgenrnn-demo.ipynb">this demo Jupyter Notebook</a>.</p>
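<p>In code form, those generation options look roughly like this (the filename and counts are arbitrary):</p>
<pre tabindex="0"><code class="language-python" data-lang="python"># print a few samples at several temperatures, like during training
textgen.generate_samples()

# write 100 generated texts at temperature 0.5 to a file
textgen.generate_to_file(&#39;generated.txt&#39;, n=100, temperature=0.5)

# store a single generated text in a variable, e.g. for a Twitter bot
new_text = textgen.generate(1, return_as_list=True)[0]
</code></pre>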
<p>Here&rsquo;s some Shakespeare generated with a 50-minute-trained model at a temperature of 0.5:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-txt" data-lang="txt"><span class="line"><span class="cl">LUCENTIO:
</span></span><span class="line"><span class="cl">And then shall good grave to my wife thee;
</span></span><span class="line"><span class="cl">Thou would the cause the brieved to me,
</span></span><span class="line"><span class="cl">And let the place and then receives:
</span></span><span class="line"><span class="cl">The rest you the foren to my ways him child,
</span></span><span class="line"><span class="cl">And marry that will be a parties and so set me that be deeds
</span></span><span class="line"><span class="cl">And then the heart and be so shall make the most as he and stand of seat.
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">GLOUCESTER:
</span></span><span class="line"><span class="cl">Your father and madam, or shall for the people
</span></span><span class="line"><span class="cl">And dead to make the truth, or a business
</span></span><span class="line"><span class="cl">As we brother to the place her great the truth;
</span></span><span class="line"><span class="cl">And that which to the smaster and her father,
</span></span><span class="line"><span class="cl">I am I was see the sun have to the royal true.
</span></span></code></pre></div><p>Not too bad, and it&rsquo;s even close to <a href="https://en.wikipedia.org/wiki/Iambic_pentameter">iambic pentameter</a>!</p>
<h2 id="tweaking-the-model">Tweaking the Model</h2>
<p>The most important model configuration options above are <code>rnn_size</code> and <code>rnn_layers</code>: these determine the complexity of the network. Typically, you&rsquo;ll see networks in tutorials be a single 128-cell or 256-cell network. However, textgenrnn&rsquo;s architecture is slightly different, as it has an attention layer which incorporates <em>all</em> the preceding model layers. As a result, it&rsquo;s much better to go deeper than wider (e.g. 4x128 is better than 1x512) unless you have a very large amount of text (&gt;10MB). <code>rnn_bidirectional</code> controls whether the recurrent neural network is bidirectional, that is, whether it processes the previous characters both forward <em>and</em> backward (which works great if the text follows specific rules, like Shakespeare&rsquo;s character headings). <code>max_length</code> determines the maximum number of characters the network uses to predict the next character; it should be increased to let the network learn longer sequences, or decreased for shorter ones.</p>
<p>Training has a few helpful options as well. <code>num_epochs</code> determines the number of full passes through the data; this can be tweaked if you want to train the model even more. <code>batch_size</code> determines the number of model sequences to train in a step: typically, the batch size for deep learning models is 32 or 128, but with a GPU, you can get a speed increase by saturating it with the given 1024 default. <code>train_size</code> determines the proportion of character samples to train on; setting it <code>&lt; 1.0</code> both speeds up each epoch and prevents the model from cheating by learning sequences verbatim. (You can set <code>&#39;validation&#39;: True</code> to run the model on the unused data after each epoch to see if the model is overfitting.)</p>
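<p>Collected into a config dict, the options discussed above look something like the sketch below; the key names follow the parameters as discussed, but the values are illustrative defaults, not recommendations:</p>
<pre tabindex="0"><code class="language-python" data-lang="python">config = {
    &#39;rnn_size&#39;: 128,            # cells per recurrent layer
    &#39;rnn_layers&#39;: 4,            # deeper beats wider for this architecture
    &#39;rnn_bidirectional&#39;: True,  # read sequences forward and backward
    &#39;max_length&#39;: 40,           # characters used to predict the next one
    &#39;num_epochs&#39;: 10,           # full passes through the data
    &#39;gen_epoch&#39;: 2,             # generate sample text every 2 epochs
    &#39;batch_size&#39;: 1024,         # large batches to saturate the GPU
    &#39;train_size&#39;: 0.8,          # hold out 20% so the model cannot memorize
    &#39;validation&#39;: False,        # evaluate on the held-out data each epoch
}
</code></pre>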
<p>Let&rsquo;s try playing with the parameters more on a new text dataset.</p>
<h2 id="word-level-text-generation-with-reddit-data">Word-Level Text Generation With Reddit Data</h2>
<p>You might be asking &ldquo;how do you obtain text data?&rdquo; The popular text-generation use cases like lyric generation and movie scripts are copyright-protected, so the data is harder to find, and even then, there might not be enough text to train a new model upon (you typically want at least 100,000 characters).</p>
<p><a href="https://www.reddit.com">Reddit</a>, however, has <em>millions</em> of submission titles which would be great to train for a model. I wrote a <a href="https://github.com/minimaxir/subreddit-generator">helper script</a> to automatically download the top <em>n</em> Reddit submissions from a given subreddit over a given period of time. If you choose subreddits with similar linguistic styles in their titles, the subreddits will even blend together! Let&rsquo;s play with the Top 20,000 Submissions in 2017 from each of <a href="https://www.reddit.com/r/politics/">/r/politics</a> and <a href="https://www.reddit.com/r/technology/">/r/technology</a>, which results in a 3.3MB file: about 3x as much data as the Shakespeare plays.</p>
<figure>

    <img loading="lazy" srcset="/2018/05/text-neural-networks/reddit_data_hu_283f931b7de2ac8f.webp 320w,/2018/05/text-neural-networks/reddit_data_hu_1a16b2b3e85d1470.webp 768w,/2018/05/text-neural-networks/reddit_data_hu_9bb51cab41ff884e.webp 1024w,/2018/05/text-neural-networks/reddit_data.png 1288w" src="reddit_data.png"/> 
</figure>

<p>One last thing that textgenrnn can do that most char-rnn implementations can&rsquo;t is generate a <em>word-level</em> model (thanks to Keras&rsquo;s tokenizers), where the model uses the <em>n</em> previous words/punctuation tokens to predict the next word/punctuation token. On the plus side, using only known words prevents crazy typos, and since the model predicts multiple &ldquo;characters&rdquo; at a time, <code>max_length</code> can be reduced proportionally, dramatically speeding up training. There are two downsides to this approach: since words are all lowercase and punctuation is its own token, the generated text cannot be immediately used without manual editing, and the model weights will be substantially larger than a character-level model&rsquo;s, since the word-level model has to store an embedding for each word (up to <code>max_words</code>, which is 10,000 by default, whereas the vocabulary size for a char-level model is 200-300).</p>
<p>Another advantage of the Colaboratory notebook is that you can quickly adjust model parameters, upload a new file, and immediately start training on it. We&rsquo;ll set <code>&#39;line_delimited&#39;: True</code> and <code>&#39;rnn_bidirectional&#39;: False</code> since Reddit titles don&rsquo;t follow specific structural rules. For word-level training, let&rsquo;s set <code>&#39;word_level&#39;: True</code> and <code>&#39;max_length&#39;: 8</code> to reflect the new training architecture. Since the sequence length has been reduced to 1/5th, we can set <code>&#39;num_epochs&#39;: 50</code> and <code>&#39;gen_epoch&#39;: 10</code> to balance it out. Rerun the config cell to update the parameters, upload the Reddit data file, and rerun training.</p>
<p>The resulting model is much better trained than the Shakespeare model; here are a few Reddit submission titles generated at a temperature of 1.0:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-txt" data-lang="txt"><span class="line"><span class="cl">report : 49 % of americans now believe all of the country ’ s effective
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">people like facebook like it &#39; s 650 of 1 %
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">uber accused of secretly - security popular service ( likely oklahoma )
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">equifax breach fallout : your salary is dead
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">sanders uses texas shooter &#39; s iphone sales
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">adobe videos will be used to sell the web
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">apple to hold cash for $ 500 service
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">fitbit just targeting solar energy
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">george bush &#39; s concept car ‘ goes for all the biggest controversy .
</span></span></code></pre></div><p>Those look pretty good, although they may need a little editing before posting on social media.</p>
<h2 id="followup">Followup</h2>
<p>These examples only train the model for a short time as a demo of textgenrnn&rsquo;s fast learning; there&rsquo;s nothing stopping you from increasing <code>num_epochs</code> even more to further refine a model. However, the training cell times out after <strong>4 hours</strong>, so set <code>num_epochs</code> accordingly; in my experience, that&rsquo;s all you need before the network converges.</p>
<p>In practice, I used this Colaboratory notebook to train <em>many</em> models for <a href="https://www.reddit.com/r/SubredditNN/">/r/SubredditNN</a>, a subreddit where the only posters are text-generating neural network bots trained on other subreddits. And the results are very funny:</p>
<figure>

    <img loading="lazy" srcset="/2018/05/text-neural-networks/subredditnn_hu_14b57f973c80496b.webp 320w,/2018/05/text-neural-networks/subredditnn_hu_c348364252281d3.webp 768w,/2018/05/text-neural-networks/subredditnn_hu_fd7c1bfc6a99ba3e.webp 1024w,/2018/05/text-neural-networks/subredditnn.png 1554w" src="subredditnn.png"/> 
</figure>

<p>Although text-generating neural networks aren&rsquo;t at the point where they can <a href="https://www.bloomberg.com/news/features/2018-05-17/i-tried-to-get-an-ai-to-write-this-story-paul-ford">write entire articles by themselves</a>, there are still many opportunities to use them just for fun! And thanks to textgenrnn, it&rsquo;s easy, fast, and cost-effective for anyone to do. Let me know if you make any interesting neural networks with textgenrnn and this notebook!</p>
]]></content:encoded>
    </item>
    <item>
      <title>Pretrained Character Embeddings for Deep Learning and Automatic Text Generation</title>
      <link>https://minimaxir.com/2017/04/char-embeddings/</link>
      <pubDate>Tue, 04 Apr 2017 06:30:00 -0700</pubDate>
      <guid>https://minimaxir.com/2017/04/char-embeddings/</guid>
      <description>Keras + TensorFlow + Pretrained character embeddings makes text generation a breeze.</description>
      <content:encoded><![CDATA[<p>Deep learning is the biggest, <a href="http://approximatelycorrect.com/2017/03/28/the-ai-misinformation-epidemic/">often misapplied</a> buzzword nowadays for getting pageviews on blogs. As a result, there have been a lot of shenanigans lately with deep learning thought pieces and how deep learning can solve <em>anything</em> and make childhood sci-fi dreams come true.</p>
<p>I&rsquo;m not a fan of <a href="http://tvtropes.org/pmwiki/pmwiki.php/Main/ClarkesThirdLaw">Clarke&rsquo;s Third Law</a>, so I spent some time checking out deep learning myself. As it turns out, with modern deep learning tools like <a href="https://github.com/fchollet/keras">Keras</a>, a higher-level framework on top of the popular <a href="https://www.tensorflow.org">TensorFlow</a> framework, deep learning is <strong>easy to learn and understand</strong>. Yes, easy. And it <em>definitely</em> does not require a PhD, or even a Computer Science undergraduate degree, to implement models or make decisions based on the output.</p>
<p>However, let&rsquo;s try something more expansive than the stereotypical deep learning tutorials.</p>
<h2 id="characters-welcome">Characters Welcome</h2>
<p>Word embeddings are a popular machine learning technique nowadays. By using an algorithm such as <a href="https://en.wikipedia.org/wiki/Word2vec">Word2vec</a>, you can obtain a numeric representation of a word, and use those values to create numeric representations of higher-level constructs like sentences/paragraphs/documents/etc.</p>
<figure>

    <img loading="lazy" srcset="/2017/04/char-embeddings/word-vectors_hu_957d4d31123d1287.webp 320w,/2017/04/char-embeddings/word-vectors_hu_b3c04f84f101893d.webp 768w,/2017/04/char-embeddings/word-vectors_hu_91898a517b108395.webp 1024w,/2017/04/char-embeddings/word-vectors.png 1066w" src="word-vectors.png"/> 
</figure>

<p>However, generating word vectors for datasets can be computationally expensive (see <a href="http://minimaxir.com/2016/08/clickbait-cluster/">my earlier post</a> which uses Apache Spark/Word2vec to create sentence vectors at scale quickly). The academic way to work around this is to use pretrained word embeddings, such as <a href="https://nlp.stanford.edu/projects/glove/">the GloVe vectors</a> collected by researchers at Stanford NLP. However, GloVe vectors are huge; the largest one (840 billion tokens at 300D) is 5.65 GB on disk and may hit issues when loaded into memory on less-powerful computers.</p>
<p>Why not work <em>backwards</em> and calculate <em>character</em> embeddings? Then you could calculate relatively few vectors which would easily fit into memory, and use those to derive word vectors, which can then be used to derive the sentence/paragraph/document/etc. vectors. But training character embeddings traditionally is significantly more computationally expensive since there are 5-6x as many tokens, and I don&rsquo;t have access to the supercomputing power of Stanford researchers.</p>
<p>Why not use the <em>existing</em> pretrained word embeddings to extrapolate the corresponding character embeddings within each word? Think &ldquo;<a href="https://en.wikipedia.org/wiki/Bag-of-words_model">bag-of-words</a>,&rdquo; except &ldquo;bag-of-characters.&rdquo; For example, from the embeddings for the word &ldquo;the&rdquo;, we can infer embeddings for &ldquo;t&rdquo;, &ldquo;h&rdquo;, and &ldquo;e&rdquo; from the parent word, and average the t/h/e vectors from <em>all</em> words/tokens in the dataset corpus. (For this post, I will only look at the 840B/300D dataset since that is the only one with capital letters, which are rather important. If you want a dataset with smaller dimensionality, apply <a href="https://en.wikipedia.org/wiki/Principal_component_analysis">PCA</a> on the final results.)</p>
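<p>The core of that averaging idea fits in a few lines; here&rsquo;s a minimal sketch (the actual script linked below handles the file-format edge cases):</p>
<pre tabindex="0"><code class="language-python" data-lang="python">from collections import defaultdict

import numpy as np

vector_sums = defaultdict(lambda: np.zeros(300))
vector_counts = defaultdict(int)

# each line of the GloVe file is: word v1 v2 ... v300
with open(&#39;glove.840B.300d.txt&#39;, encoding=&#39;utf-8&#39;) as f:
    for line in f:
        parts = line.rstrip().split(&#39; &#39;)
        word, vector = parts[0], np.asarray(parts[1:], dtype=np.float32)
        for char in word:
            if 33 &lt;= ord(char) &lt;= 126:  # printable, non-space ASCII only
                vector_sums[char] += vector
                vector_counts[char] += 1

# average the accumulated vectors: one 300D embedding per character
char_embeddings = {c: vector_sums[c] / vector_counts[c] for c in vector_sums}
</code></pre>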
<p>I wrote a <a href="https://github.com/minimaxir/char-embeddings/blob/master/create_embeddings.py">simple Python script</a> that takes in the specified pretrained word embeddings and does just that, <a href="https://github.com/minimaxir/char-embeddings/blob/master/glove.840B.300d-char.txt">outputting the character embeddings</a> in the same format. (For simplicity, only ASCII characters are included; the <a href="https://en.wikipedia.org/wiki/Extended_ASCII">extended ASCII characters</a> are intentionally omitted for compatibility reasons. Additionally, by construction, space and newline characters are not represented in the derived dataset.)</p>
<figure>

    <img loading="lazy" srcset="/2017/04/char-embeddings/char-embeddings_hu_2c1ac4747dfe7721.webp 320w,/2017/04/char-embeddings/char-embeddings.png 674w" src="char-embeddings.png"/> 
</figure>

<p>You may be thinking that I&rsquo;m cheating. So let&rsquo;s set a point-of-reference. Colin Morris <a href="http://colinmorris.github.io/blog/1b-words-char-embeddings">found</a> that when 16D character embeddings from a model used in Google&rsquo;s <a href="https://arxiv.org/abs/1312.3005">One Billion Word Benchmark</a> are projected into a 2D space via t-SNE, patterns emerge: digits are close, lowercase and uppercase letters are often paired, and punctuation marks are loosely paired.</p>
<figure>

    <img loading="lazy" srcset="/2017/04/char-embeddings/tsne_embeddings_hu_24701626db450376.webp 320w,/2017/04/char-embeddings/tsne_embeddings_hu_915d2c16f5761646.webp 768w,/2017/04/char-embeddings/tsne_embeddings.png 795w" src="tsne_embeddings.png"/> 
</figure>

<p>Let&rsquo;s do that for my derived character embeddings, but with <a href="https://www.r-project.org">R</a> and <a href="http://docs.ggplot2.org/current/">ggplot2</a>. t-SNE is <a href="http://distill.pub/2016/misread-tsne/">difficult to use</a> for high-dimensional vectors, as different combinations of parameters can result in wildly different output, so let&rsquo;s try a couple of projections. Here&rsquo;s what happens when my pretrained embeddings are preprojected from 300D to 16D via <a href="http://ufldl.stanford.edu/tutorial/unsupervised/PCAWhitening/">PCA whitening</a>, with the perplexity (number of optimal neighbors) set to 7:</p>
<figure>

    <img loading="lazy" srcset="/2017/04/char-embeddings/char-tsne_hu_d3036836f3686a8.webp 320w,/2017/04/char-embeddings/char-tsne_hu_b34e859ed7bc3483.webp 768w,/2017/04/char-embeddings/char-tsne_hu_5ad979b9989688d.webp 1024w,/2017/04/char-embeddings/char-tsne.png 1500w" src="char-tsne.png"/> 
</figure>

<p>The algorithm manages to separate and group lowercase, uppercase, and numerals rather distinctly. Quadrupling the dimensionality of the preprocessing step to 64D and changing perplexity to 2 generates a depiction closer to the Google model projection:</p>
<figure>

    <img loading="lazy" srcset="/2017/04/char-embeddings/char-tsne-2_hu_db49ec263903c70d.webp 320w,/2017/04/char-embeddings/char-tsne-2_hu_3ba284cdb5f46ed5.webp 768w,/2017/04/char-embeddings/char-tsne-2_hu_8e5eb2b75768cbe8.webp 1024w,/2017/04/char-embeddings/char-tsne-2.png 1500w" src="char-tsne-2.png"/> 
</figure>

<p>My pretrained character embeddings trick isn&rsquo;t academic, but it&rsquo;s successfully identifying realistic relationships. There might be something here worthwhile.</p>
<h2 id="the-coolness-of-deep-learning">The Coolness of Deep Learning</h2>
<p>Keras, maintained by Google employee <a href="https://twitter.com/fchollet">François Chollet</a>, is so good that it is effectively cheating in the field of machine learning, where even TensorFlow tutorials can be replaced with a single line of code (which is important for iteration; Keras layers are effectively Lego blocks). A simple read of the <a href="https://github.com/fchollet/keras/tree/master/examples">Keras examples</a> and <a href="https://keras.io/">documentation</a> will let you reverse-engineer most of the revolutionary deep learning clickbait thought pieces. Some create entire startups by changing the source dataset of the Keras examples and pitching them to investors none the wiser, or make very light wrappers on top of the examples for teaching tutorial videos and get thousands of subscribers on YouTube.</p>
<p>I prefer to parse documentation/examples as a proof-of-concept, but never as gospel. Examples are often not the most efficient way to implement a solution to a problem, merely a start. In the case of Keras&rsquo;s <a href="https://github.com/fchollet/keras/blob/master/examples/lstm_text_generation.py">text generator example</a>, the initial code was likely modeled after the 2015 blog post <a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness/">The Unreasonable Effectiveness of Recurrent Neural Networks</a> by Andrej Karpathy and the corresponding project <a href="https://github.com/karpathy/char-rnn">char-rnn</a>. There have been many new developments in neural network architecture since 2015 that can improve both the speed and performance of the text generation model as a whole.</p>
<h2 id="what-text-to-generate">What Text to Generate?</h2>
<p>The Keras example uses <a href="https://en.wikipedia.org/wiki/Friedrich_Nietzsche">Nietzsche</a> writings as a data source, which I&rsquo;m not fond of because it&rsquo;s difficult to differentiate bad autogenerated Nietzsche rants from actual Nietzsche rants. What I want to generate is text with <em>rules</em>, with the algorithm being judged by how well it follows an inherent structure. My idea is to create <a href="http://magic.wizards.com/en">Magic: The Gathering</a> cards.</p>
<figure>

    <img loading="lazy" srcset="/2017/04/char-embeddings/dragon-whelp_hu_aedb3bfaf354fd82.webp 320w,/2017/04/char-embeddings/dragon-whelp.jpg 464w" src="dragon-whelp.jpg"/> 
</figure>

<p>Inspired by the <a href="https://twitter.com/RoboRosewater">@RoboRosewater</a> Twitter account by Reed Milewicz and the <a href="http://www.mtgsalvation.com/forums/creativity/custom-card-creation/612057-generating-magic-cards-using-deep-recurrent-neural">corresponding research</a> and <a href="https://motherboard.vice.com/en_us/article/the-ai-that-learned-magic-the-gathering">articles</a>, I aim to see if it&rsquo;s possible to recreate the structured design creativity for myself.</p>
<p>Even if you are not familiar with Magic and its rules, you can still find the <a href="https://twitter.com/RoboRosewater/status/756198572282949632">card text</a> of RoboRosewater cards hilarious:</p>
<figure>

    <img loading="lazy" srcset="/2017/04/char-embeddings/horse_hu_1ffb7f6bcc09f430.webp 320w,/2017/04/char-embeddings/horse.jpeg 375w" src="horse.jpeg"/> 
</figure>

<p>Occasionally RoboRosewater, using a weaker model, produces amusing <a href="https://twitter.com/RoboRosewater/status/689184317721960448">neural network trainwrecks</a>:</p>
<figure>

    <img loading="lazy" srcset="/2017/04/char-embeddings/carl_hu_636d0892c1942778.webp 320w,/2017/04/char-embeddings/carl.png 375w" src="carl.png"/> 
</figure>

<p>More importantly, all Magic cards have an explicit structure; they have a name, mana cost in the upper-right, card type, card text, and usually a power and toughness in the bottom-right.</p>
<p>I wrote <a href="https://github.com/minimaxir/char-embeddings/blob/master/create_magic_text.py">another Python script</a> to parse all Magic card data from <a href="https://mtgjson.com">MTG JSON</a> into an encoding which matches this architecture, where each section transition has its own symbol delimiter, along with other encoding simplicities. For example, here is the card <a href="http://gatherer.wizards.com/Pages/Card/Details.aspx?multiverseid=247314">Dragon Whelp</a> in my encoding:</p>
<pre tabindex="0"><code>[Dragon Whelp@{2}{R}{R}#Creature — Dragon$Flying|{R}: ~ gets +1/+0 until end of turn. If this ability has been activated four or more times this turn, sacrifice ~ at the beginning of the next end step.%2^3]
</code></pre><p>These card encodings are all combined into one .txt file, which will be fed into the model.</p>
<h2 id="building-and-training-the-model">Building and Training the Model</h2>
<p>The Keras text generation example operates by breaking a given .txt file into 40-character sequences, and the model tries to predict the 41st character by outputting a probability for each possible character (108 in this dataset). For example, if the input based on the above example is <code>['D', 'r', 'a', 'g', ..., 'D', 'r', 'a', 'g']</code> (with the latter Drag being part of the creature type), the model will optimize for outputting a probability of 1.0 of <code>o</code>; per the <a href="https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression">categorical crossentropy</a> loss function, the model is rewarded for assigning correct guesses with 1.0 probability and incorrect guesses with 0.0 probabilities, penalizing half-guesses and wrong guesses.</p>
<p>Each possible 40-character sequence is collected; however, only every third sequence is kept, which prevents the model from learning card text verbatim and also makes training faster (for this model, there are about <strong>1 million</strong> sequences in the final training set). The example uses only a single 128-node <a href="https://en.wikipedia.org/wiki/Long_short-term_memory">long short-term memory</a> (LSTM) <a href="https://en.wikipedia.org/wiki/Recurrent_neural_network">recurrent neural network</a> (RNN) layer, popular for incorporating a &ldquo;memory&rdquo; into a neural network model, but the example notes at the beginning that it can take a while to train before the generated text is coherent.</p>
<p>There are a few optimizations we can make. Instead of supplying the characters directly to the RNN, we can first encode them using an <a href="https://keras.io/layers/embeddings/">Embedding layer</a> so the model can learn character context. We can stack more layers on the RNN by adding a 2-level <a href="https://en.wikipedia.org/wiki/Multilayer_perceptron">multilayer perceptron</a>: a <a href="https://www.reddit.com/r/ProgrammerHumor/comments/5si1f0/machine_learning_approaches/">meme</a>, yes, but it helps, as the network must learn latent representations of the data. Thanks to recent developments such as <a href="https://arxiv.org/abs/1502.03167">batch normalization</a> and <a href="https://en.wikipedia.org/wiki/Rectifier_%28neural_networks%29">rectified linear activations</a>, these <a href="https://keras.io/layers/core/#dense">Dense layers</a> can be trained without much computational overhead, and thanks to Keras, both can be added to a layer with a single line of code each. Lastly, we can add an auxiliary output via Keras&rsquo;s <a href="https://keras.io/models/model/">functional API</a>, where the network makes a prediction based only on the output from the RNN in addition to the main output; this forces the RNN to work smarter and ends up resulting in a <em>significant</em> improvement in loss for the main path.</p>
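<p>In Keras&rsquo;s functional API, that architecture is only a handful of lines; here&rsquo;s a minimal sketch, with layer sizes chosen for illustration rather than matching the trained model exactly:</p>
<pre tabindex="0"><code class="language-python" data-lang="python">from keras.layers import (Activation, BatchNormalization, Dense,
                          Embedding, Input, LSTM)
from keras.models import Model

num_chars = 108  # size of the character vocabulary
seq_input = Input(shape=(40,), dtype=&#39;int32&#39;)

# character embeddings (can be initialized from the pretrained 300D vectors)
embedded = Embedding(num_chars, 300, input_length=40)(seq_input)
rnn = LSTM(128)(embedded)

# auxiliary output: predict the next character from the RNN output alone
aux_output = Dense(num_chars, activation=&#39;softmax&#39;, name=&#39;aux_out&#39;)(rnn)

# main path: 2-level MLP with batch normalization and ReLU activations
x = Dense(256)(rnn)
x = BatchNormalization()(x)
x = Activation(&#39;relu&#39;)(x)
x = Dense(256)(x)
x = BatchNormalization()(x)
x = Activation(&#39;relu&#39;)(x)
main_output = Dense(num_chars, activation=&#39;softmax&#39;, name=&#39;main_out&#39;)(x)

model = Model(inputs=seq_input, outputs=[main_output, aux_output])
model.compile(loss=&#39;categorical_crossentropy&#39;, optimizer=&#39;adam&#39;)
</code></pre>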
<p>The final architecture ends up looking like this:</p>
<figure>

    <img loading="lazy" srcset="/2017/04/char-embeddings/model_hu_37b43e25201e614b.webp 320w,/2017/04/char-embeddings/model_hu_dddf239d67693047.webp 768w,/2017/04/char-embeddings/model.png 930w" src="model.png"/> 
</figure>

<p>And because we added an Embedding layer, we can load the pretrained 300D character embeds I made earlier, giving the model a good start in understanding character relationships.</p>
<p>The goal of training is to minimize the total loss of the model (though for evaluating model performance, we only look at the loss of the main output). The model is trained in <strong>epochs</strong>, in which the model sees all the input data at least once. During each epoch, batches of size 128 are loaded into the model and evaluated, calculating a <strong>batch loss</strong> for each; the gradients from the batch are backpropagated through the previous layers to improve them. While training with Keras, the console reports an <strong>epoch loss</strong>, the average of all the batch losses so far in the current epoch, allowing the user to see in real time how the model improves. It&rsquo;s addicting.</p>
<figure>

    <img loading="lazy" srcset="/2017/04/char-embeddings/keras-training_hu_52c660d870466682.webp 320w,/2017/04/char-embeddings/keras-training_hu_1d55a44e4dbe3025.webp 768w,/2017/04/char-embeddings/keras-training.gif 994w" src="keras-training.gif"/> 
</figure>

<p>Keras/TensorFlow works just fine on the CPU, but for models with a RNN, you&rsquo;ll want to consider using a GPU for performance, specifically one by nVidia. Amazon has cloud GPU instances for $0.90/hr (<a href="http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Stop_Start.html">not prorated</a>), but very recently, Google announced <a href="https://cloud.google.com/compute/docs/gpus/add-gpus">GPU instances</a> of the same caliber for ~$0.75/hr (prorated to the minute), which is what I used to train this model, although Google Compute Engine requires configuring the GPU drivers first. For 20 epochs, it took about 4 hours and 20 minutes to train the model while spending $3.26, which isn&rsquo;t bad as far as deep learning goes.</p>
<h2 id="making-magic">Making Magic</h2>
<p>After each epoch, the original Keras text generation example takes a sentence from the input data as a seed and predicts the next character in the sequence according to the model, then uses the last 40 generated characters to predict the character after that, and so on. The sampling incorporates a diversity/temperature parameter which allows the model to make suboptimal decisions and select characters with lower natural probabilities, allowing for the romantic &ldquo;creativity&rdquo; popular with neural network text generation.</p>
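<p>The sampling step itself is tiny; here&rsquo;s a sketch in the spirit of the Keras example&rsquo;s helper function (the small epsilon guarding against log(0) is my addition):</p>
<pre tabindex="0"><code class="language-python" data-lang="python">import numpy as np

def sample(preds, temperature=1.0):
    # reweight the predicted probabilities, then draw a character index
    preds = np.asarray(preds).astype(&#39;float64&#39;)
    preds = np.log(preds + 1e-8) / temperature  # low temp sharpens, high temp flattens
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    return np.argmax(np.random.multinomial(1, preds, 1))
</code></pre>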
<p>With the Magic card dataset and my tweaked model architecture, generated text is coherent <a href="https://github.com/minimaxir/char-embeddings/blob/master/output/iter-01-0_9204.txt">after the 1st epoch</a>! After about 20 epochs, training becomes super slow, but the predicted text becomes super interesting. Here are a few fun examples from a <a href="https://github.com/minimaxir/char-embeddings/blob/master/output/text_sample.txt">list of hundreds of generated cards</a>. (Note: the power/toughness values at the end of the card have issues; more on that later).</p>
<p>With low diversity, the neural network generated cards that are oddly biased toward card names which include the letter &ldquo;S&rdquo;. The card text also conforms to the rules of the game very well.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-txt" data-lang="txt"><span class="line"><span class="cl">[Reality Spider@{3}{G}#Creature — Elf Warrior$Whenever ~ deals combat damage to a player, put a +1/+1 counter on it.%^]
</span></span><span class="line"><span class="cl">[Dark Soul@{2}{R}#Instant$~ deals 2 damage to each creature without flying.%^]
</span></span><span class="line"><span class="cl">[Standing Stand@{2}{G}#Creature — Elf Shaman${1}{G}, {T}: Draw a card, then discard a card.%^]
</span></span></code></pre></div><p>In contrast, cards generated with high diversity hit the uncanny valley of coherence and incoherence in both text and game mechanic abuse, which is what makes them interesting.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-txt" data-lang="txt"><span class="line"><span class="cl">[Portrenline@{2}{R}#Sorcery$As an additional cost to cast ~, exile ~.%^]
</span></span><span class="line"><span class="cl">[Clocidian Lorid@{W}{W}{W}#Instant$Regenerate each creature with flying and each player.%^]
</span></span><span class="line"><span class="cl">[Icomic Convermant@{3}{G}#Sorcery$Search your library for a land card in your graveyard.%1^1]
</span></span></code></pre></div><p>The best-of-both-worlds cards are generated from diversity parameters between both extremes, and often have funny names.</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-txt" data-lang="txt"><span class="line"><span class="cl">[Seal Charm@{W}{W}#Instant$Exile target creature. Its controller loses 1 life.%^]
</span></span><span class="line"><span class="cl">[Shambling Assemblaster@{4}{W}#Creature — Human Cleric$When ~ enters the battlefield, destroy target nonblack creature.%1^1]
</span></span><span class="line"><span class="cl">[Lightning Strength@{3}{R}#Enchantment — Aura$Enchant creature|Enchanted creature gets +3/+3 and has flying, flying, trample, trample, lifelink, protection from black and votile all damage unless you return that card to its owner&#39;s hand.%2^2]
</span></span><span class="line"><span class="cl">[Skysor of Shadows@{7}{B}{B}{B}#Enchantment$As ~ enters the battlefield, choose one —|• Put a -1/-1 counter on target creature.%2^2]
</span></span><span class="line"><span class="cl">[Glinding Stadiers@{4}{W}#Creature — Spirit$Protection from no creatures can&#39;t attack.%^]
</span></span><span class="line"><span class="cl">[Dragon Gault@{3}{G}{U}{U}#Creature — Kraven$~&#39;s power and toughness are 2.%2^2]
</span></span></code></pre></div><p>All Keras/Python code used in this blog post, along with sample Magic card output and the trained model itself, is available open-source <a href="https://github.com/minimaxir/char-embeddings">in this GitHub repository</a>. The repo additionally contains <a href="https://github.com/minimaxir/char-embeddings/blob/master/text_generator_keras_sample.py">a Python script</a> which lets you generate new cards using the model, too!</p>
<h2 id="visualizing-model-performance">Visualizing Model Performance</h2>
<p>One thing deep learning tutorials rarely mention is <em>how</em> to collect the loss data and visualize the change in loss over time. Thanks to Keras&rsquo;s <a href="https://keras.io/callbacks/">utility functions</a>, I wrote a custom model callback which collects the batch losses and epoch losses and writes them to a CSV file.</p>
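<p>A minimal sketch of such a callback (not the exact one in the repository, but the same idea):</p>
<pre tabindex="0"><code class="language-python" data-lang="python">import csv

from keras.callbacks import Callback

class LossLogger(Callback):
    # write every batch loss and epoch loss to a CSV file during training
    def __init__(self, path=&#39;log.csv&#39;):
        super(LossLogger, self).__init__()
        self.file = open(path, &#39;w&#39;)
        self.writer = csv.writer(self.file)
        self.writer.writerow([&#39;type&#39;, &#39;index&#39;, &#39;loss&#39;])

    def on_batch_end(self, batch, logs=None):
        self.writer.writerow([&#39;batch&#39;, batch, logs.get(&#39;loss&#39;)])

    def on_epoch_end(self, epoch, logs=None):
        self.writer.writerow([&#39;epoch&#39;, epoch, logs.get(&#39;loss&#39;)])

    def on_train_end(self, logs=None):
        self.file.close()

# usage: model.fit(X, y, callbacks=[LossLogger(&#39;log.csv&#39;)])
</code></pre>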
<p>Using R and ggplot2, I can plot the batch loss at every 50th batch to visualize how the model converges over time.</p>
<figure>

    <img loading="lazy" srcset="/2017/04/char-embeddings/batch-losses_hu_797eb59cd8e69be1.webp 320w,/2017/04/char-embeddings/batch-losses_hu_beacae9329be9663.webp 768w,/2017/04/char-embeddings/batch-losses_hu_a8a63ca12c37bf2e.webp 1024w,/2017/04/char-embeddings/batch-losses.png 1200w" src="batch-losses.png"/> 
</figure>

<p>After 20 epochs, the model loss ends up at about <strong>0.30</strong>, which is more than low enough for coherent text. As you can see, there are large diminishing returns after a few epochs, which is the hard part of training deep learning models.</p>
<p>Plotting the epoch loss over the batches makes the trend more clear.</p>
<figure>

    <img loading="lazy" srcset="/2017/04/char-embeddings/epoch-losses_hu_91beebdcbcd7000d.webp 320w,/2017/04/char-embeddings/epoch-losses_hu_9e50f22169739719.webp 768w,/2017/04/char-embeddings/epoch-losses_hu_7388ccd5aa56820b.webp 1024w,/2017/04/char-embeddings/epoch-losses.png 1200w" src="epoch-losses.png"/> 
</figure>

<p>In order to prevent early convergence, we can make the model more complex (i.e. stack more layers, unironically), but that has trade-offs in both training <em>and</em> prediction speed, the latter of which is important if using deep learning in a production application.</p>
<p>Lastly, as with the Google One Billion Words benchmark, we can extract the <a href="https://github.com/minimaxir/char-embeddings/blob/master/output/char-embeddings.txt">trained character embeddings</a> from the model (now augmented with Magic card context!) and plot them again to see what has changed.</p>
<figure>

    <img loading="lazy" srcset="/2017/04/char-embeddings/char-tsne-embed_hu_3b7b620e86831bdb.webp 320w,/2017/04/char-embeddings/char-tsne-embed_hu_74a89bea26895dc3.webp 768w,/2017/04/char-embeddings/char-tsne-embed_hu_beed0a100f0bdfea.webp 1024w,/2017/04/char-embeddings/char-tsne-embed.png 1500w" src="char-tsne-embed.png"/> 
</figure>

<p>There are more pairs of uppercase/lowercase characters, although interestingly there isn&rsquo;t much grouping among the special characters added as section breaks in the encoding, or among the mechanical uppercase characters such as W/U/B/R/G/C/T.</p>
<h2 id="next-steps">Next Steps</h2>
<p>After building the model, I did a little more research to see whether others had solved the power/toughness problem. Since the input sentences are only 40 characters and Magic cards are much longer than that, power/toughness are likely out of scope for the model, and it cannot learn their exact values. It turns out that the intended solution is to use a <a href="https://github.com/billzorn/mtgencode">completely different encoding</a>, such as this one for Dragon Whelp:</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-txt" data-lang="txt"><span class="line"><span class="cl">|5creature|4|6dragon|7|8&amp;^^/&amp;^^^|9flying\{RR}: @ gets +&amp;^/+&amp; until end of turn. if this ability has been activated four or more times this turn, sacrifice @ at the beginning of the next end step.|3{^^RRRR}|0N|1dragon whelp|
</span></span></code></pre></div><p>Power/toughness are generated near the <em>beginning</em> of the card. Sections are delimited by pipes, with a numeral designating the corresponding section. Instead of numerals for card values, carets are used, which provides a more accurate <em>quantification</em> of values. With this encoding, each character has a <em>singular purpose</em> in the global card context, and their embeddings would likely generate more informative visualizations. (As a consequence, though, the generated cards are harder to parse at a glance.)</p>
<p>The secondary encoding highlights a potential flaw in my methodology of using pretrained character embeddings. Trained machine learning models must be used apples-to-apples on similar datasets; for example, you can&rsquo;t accurately perform Twitter <a href="https://en.wikipedia.org/wiki/Sentiment_analysis">sentiment analysis</a> using a model trained on professional movie reviews, since tweets do not follow <a href="https://owl.english.purdue.edu/owl/resource/735/02/">AP Style</a> guidelines. In my case, the <a href="http://commoncrawl.org">Common Crawl</a>, the source of the pretrained embeddings, follows more natural text usage and would not transfer cleanly to the atypical character usage in <em>either</em> of the Magic card encodings.</p>
<p>There&rsquo;s still a <em>lot</em> of work to be done, both in working with pretrained character embeddings and in improving Magic card generation, but I believe there is promise. A better way to create character embeddings than my script is to do it the hard way and train them manually, perhaps even at a higher dimensionality like 500D or 1000D. Likewise, for Magic model building, the <a href="https://github.com/billzorn/mtgencode#training-a-neural-net">mtg-rnn instructions</a> recommend a large LSTM stacked on another LSTM along with 120/200-character sentences, which combined make training <strong>VERY</strong> slow (notably, this was the architecture of the <a href="https://github.com/fchollet/keras/commit/d2b229df2ea0bab712379c418115bc44508bc6f9#diff-904d72bcf9fa38b32f9c1f868ff59367">very first commit</a> for the Keras text generation example, and <a href="https://github.com/fchollet/keras/commit/01d5e7bc4782daafcfa99e035c1bdbe13a985145">was changed</a> to the more easily trainable architecture). There is also promise in a <a href="http://kvfrans.com/variational-autoencoders-explained/">variational autoencoder</a> approach, such as with <a href="https://arxiv.org/abs/1702.02390">textvae</a>.</p>
<p>This work is potentially very expensive and I am strongly considering setting up a <a href="https://www.patreon.com">Patreon</a> in lieu of excess venture capital to subsidize my machine learning/deep learning tasks in the future.</p>
<p>At minimum, working through this example gave me plenty of practical experience with Keras, and another tool in my toolbox for data analysis and visualization. Keras makes the model-construction aspect of deep learning trivial and not scary. Hopefully, this article justifies the use of the &ldquo;deep learning&rdquo; buzzword in the headline.</p>
<p>It&rsquo;s also worth mentioning that I actually started working on automatic text generation 6 months ago using a different, non-deep-learning approach, but hit a snag and abandoned that project. With my work on Keras, I found a way around that snag, and on the same Magic dataset with the same input construction, I obtained a model loss of <strong>0.03</strong> at <strong>20% of the cloud computing cost</strong> in about the same amount of time. More on that later.</p>
<hr>
<p><em>The code for generating the R/ggplot2 data visualizations is available in this <a href="http://minimaxir.com/notebooks/char-tsne/">R Notebook</a>, and open-sourced in <a href="https://github.com/minimaxir/char-tsne-visualization">this GitHub Repository.</a></em></p>
<p><em>You are free to use the automatic text generation scripts and data visualizations from this article however you wish, but it would be greatly appreciated if proper attribution is given to this article and/or myself!</em></p>
]]></content:encoded>
    </item>
    <item>
      <title>Predicting And Mapping Arrest Types in San Francisco with LightGBM, R, ggplot2</title>
      <link>https://minimaxir.com/2017/02/predicting-arrests/</link>
      <pubDate>Wed, 08 Feb 2017 06:30:00 -0700</pubDate>
      <guid>https://minimaxir.com/2017/02/predicting-arrests/</guid>
      <description>Given that a SF police arrest occurs at a specified time and place, what is the reason for that arrest?</description>
      <content:encoded><![CDATA[<p>The new hotness in the world of data science is <a href="https://en.wikipedia.org/wiki/Artificial_neural_network">neural networks</a>, which form the basis of <a href="https://en.wikipedia.org/wiki/Deep_learning">deep learning</a>. But while everyone is obsessing about neural networks and how deep learning is <em>magic</em> and can solve <em>any</em> problem if you just <a href="https://www.reddit.com/r/ProgrammerHumor/comments/5si1f0/machine_learning_approaches/">stack enough layers</a>, there have been many recent developments in the relatively nonmagical world of machine learning with <em>boring</em> CPUs.</p>
<p>Years before neural networks were the Swiss army knife of data science, there were <a href="https://en.wikipedia.org/wiki/Gradient_boosting">gradient-boosted machines</a>/<a href="https://en.wikipedia.org/wiki/Gradient_boosting#Gradient_tree_boosting">gradient-boosted trees</a>. GBMs/GBTs are machine learning methods which are effective on many types of data, and do not require the <a href="http://r-statistics.co/Assumptions-of-Linear-Regression.html">traditional model assumptions</a> of linear/logistic regression models. Wikipedia has a good article on the advantages of <a href="https://en.wikipedia.org/wiki/Decision_tree_learning">decision tree learning</a>, and visual diagrams of the architecture:</p>
<figure>

    <img loading="lazy" srcset="/2017/02/predicting-arrests/CART_tree_titanic_survivors_hu_8938562921d14250.webp 320w,/2017/02/predicting-arrests/CART_tree_titanic_survivors.png 360w" src="CART_tree_titanic_survivors.png"/> 
</figure>

<p>GBMs, as <a href="http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingClassifier.html">implemented</a> in the Python package <a href="http://scikit-learn.org/stable/">scikit-learn</a>, are extremely popular in <a href="https://www.kaggle.com">Kaggle</a> machine learning competitions. But scikit-learn is relatively old, and new technologies have emerged which implement GBMs/GBTs on large datasets with massive parallelization and in-memory computation. A popular big data machine learning library, <a href="http://www.h2o.ai">H2O</a>, has a <a href="http://docs.h2o.ai/h2o-tutorials/latest-stable/tutorials/gbm-randomforest/index.html">famous GBM implementation</a> which, <a href="https://github.com/szilard/benchm-ml">per benchmarks</a>, is over 10x faster than scikit-learn and is optimized for datasets with millions of records. But even <em>faster</em> than H2O is <a href="https://github.com/dmlc/xgboost">xgboost</a>, which can hit 5x-10x speedups relative to H2O, depending on the dataset size.</p>
<p>Enter <a href="https://github.com/Microsoft/LightGBM">LightGBM</a>, a new (October 2016) open-source machine learning framework by <a href="https://www.microsoft.com/en-us/">Microsoft</a> which, per <a href="https://github.com/Microsoft/LightGBM/issues/211">benchmarks</a> on release, was up to <em>4x faster</em> than xgboost! (xgboost very recently implemented a <a href="https://github.com/dmlc/xgboost/issues/1950">technique</a> also used in LightGBM, which reduced the relative speedup to just ~2x). As a result, LightGBM allows for very efficient model building on large datasets without requiring cloud computing or nVidia CUDA GPUs.</p>
<p>A year ago, I <a href="http://minimaxir.com/2015/12/sf-arrests/">wrote an analysis</a> of the types of police arrests in San Francisco, using data from the <a href="https://data.sfgov.org">SF OpenData</a> initiative, with a <a href="http://minimaxir.com/2015/12/sf-arrest-maps/">followup article</a> analyzing the locations of these arrests. Months later, the same source dataset was used <a href="https://www.kaggle.com/c/sf-crime">for a Kaggle competition</a>. Why not give the dataset another look and test LightGBM out?</p>
<h2 id="playing-with-the-data">Playing With The Data</h2>
<p><em>(You can view the R code used to process the data and generate the data visualizations in <a href="http://minimaxir.com/notebooks/predicting-arrests/">this R Notebook</a>)</em></p>
<p>The <a href="https://data.sfgov.org/Public-Safety/SFPD-Incidents-from-1-January-2003/tmnf-yvry">SFPD Incidents</a> dataset includes crime incidents in San Francisco from 1/1/2003 to 1/17/2017 (at time of analysis). Filtering the dataset to only the incidents which resulted in arrests (since most incidents are trivial) leaves a dataset of 634,299 arrests total. The dataset also includes information on the type of crime, the location where the arrest occurred, and the date/time. There are 39 different types of arrests in the <strong>Category</strong> column, such as Assault, Burglary, and Prostitution, which serves as the response variable.</p>
<figure>

    <img loading="lazy" srcset="/2017/02/predicting-arrests/data_hu_5459250938b0d49e.webp 320w,/2017/02/predicting-arrests/data_hu_de0cf1293adcafa9.webp 768w,/2017/02/predicting-arrests/data_hu_936b9e66b6422169.webp 1024w,/2017/02/predicting-arrests/data.png 1612w" src="data.png"/> 
</figure>

<p>Meanwhile, we can engineer features from the location and date/time.
Performing an exploratory data analysis of both is helpful to determine at a glance which features may be relevant (fortunately, I did that a year ago).</p>
<p>The location is given as latitude/longitude coordinates, so we can select a longitude <strong>X</strong> and latitude <strong>Y</strong> as features. Date/time can be deconstructed further. We can extract the <strong>hour</strong> in which a given arrest occurred as a feature (hour can take 24 different values, from 0 to 23). Likewise, we can extract the <strong>month</strong> in a similar manner (12 values, from 1 to 12). The <strong>year</strong> the crime occurred (2003 to 2017) can be extracted without special encoding; it is always helpful to include a year feature in predictive models to help account for change over time. The <strong>DayOfWeek</strong> is important, but encoding it as a numeric value is trickier: we logically encode each day of the week from 1 to 7, but which day should be #1? Making Monday #1 and Sunday #7 is the most logical, since a decision tree rule that sets a threshold on DayOfWeek values &gt; 5 will translate logically to a weekend.</p>
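<p>Sketched in Python/pandas for illustration (the actual processing is in the linked R Notebook; the file and column names here are assumptions based on the dataset description):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import pandas as pd

df = pd.read_csv('sfpd_incidents.csv')
# keep only incidents that resulted in an arrest (column name assumed)
df = df[df['Resolution'].str.contains('ARREST', na=False)]

dates = pd.to_datetime(df['Date'])
df['hour'] = pd.to_datetime(df['Time'], format='%H:%M').dt.hour   # 0 to 23
df['month'] = dates.dt.month                                      # 1 to 12
df['year'] = dates.dt.year                                        # 2003 to 2017
df['DayOfWeek'] = dates.dt.dayofweek + 1        # Monday = 1, ..., Sunday = 7
# X (longitude) and Y (latitude) are already numeric columns in the data
</code></pre></div>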
<figure>

    <img loading="lazy" srcset="/2017/02/predicting-arrests/predict_matrix_hu_74733e46c443937b.webp 320w,/2017/02/predicting-arrests/predict_matrix.png 716w" src="predict_matrix.png"/> 
</figure>

<p>That&rsquo;s six features total. There are more features which could be helpful, but let&rsquo;s check a baseline model as a start.</p>
<h2 id="modeling">Modeling</h2>
<p>Specifically, the model will answer the question: <em>given that a San Francisco police arrest occurs at a specified time and place, what is the reason for that arrest?</em></p>
<p>For this post, I will use the <a href="https://github.com/Microsoft/LightGBM/tree/master/R-package">R package</a> for LightGBM (which was beta-released in January 2017; it&rsquo;s <em>extremely</em> cutting edge!). We split the dataset 70%/30% into a training set of 444,011 arrests and a test set of 190,288 arrests (due to the large number of different category labels, the split must be <a href="https://en.wikipedia.org/wiki/Stratified_sampling">stratified</a> to ensure the training and test sets have a balanced distribution of labels; in R, this can be implemented with the <code>caret</code> package and <code>createDataPartition</code>).</p>
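<p>The equivalent split, sketched with scikit-learn for illustration (the post itself uses <code>caret::createDataPartition</code> in R):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">from sklearn.model_selection import train_test_split

# integer-encode the 39 Category labels as 0..38 before splitting
df['label'] = df['Category'].astype('category').cat.codes

train_df, test_df = train_test_split(
    df, test_size=0.3, stratify=df['Category'], random_state=123)
</code></pre></div>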
<p>LightGBM trains the model on the training set and evaluates it on the test set to minimize the <a href="https://www.kaggle.com/c/sf-crime#evaluation">multiclass logarithmic loss</a> of the model. For now, I use the <a href="https://github.com/Microsoft/LightGBM/blob/master/docs/Parameters.md">default parameters</a> of LightGBM, except to massively increase the number of iterations of the training algorithm, and to stop training the model early if the model stops improving. After about 4 minutes on my laptop (which is very fast for a dataset of this size!), the model returns a multilogloss of <strong>1.98</strong>.</p>
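<p>A rough sketch of this training setup using LightGBM&rsquo;s Python API (again, the post uses the R package; the parameter values here are illustrative, not the exact ones used):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import lightgbm as lgb

features = ['X', 'Y', 'hour', 'month', 'year', 'DayOfWeek']
dtrain = lgb.Dataset(train_df[features], label=train_df['label'])
dvalid = lgb.Dataset(test_df[features], label=test_df['label'],
                     reference=dtrain)

params = {
    'objective': 'multiclass',
    'num_class': 39,
    'metric': 'multi_logloss',   # the metric being minimized
}

# many boosting rounds, stopping early once the validation loss plateaus
model = lgb.train(params, dtrain, num_boost_round=5000,
                  valid_sets=[dvalid],
                  callbacks=[lgb.early_stopping(stopping_rounds=50)])
</code></pre></div>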
<p>That number sounds arbitrary. Is it good or bad? Let&rsquo;s compare it to the multilogloss from the <a href="https://www.kaggle.com/c/sf-crime/leaderboard">top models</a> from the Kaggle version of the dataset, where a lower score is better:</p>
<figure>

    <img loading="lazy" srcset="/2017/02/predicting-arrests/kaggle_hu_2434549f8c0f370a.webp 320w,/2017/02/predicting-arrests/kaggle.png 521w" src="kaggle.png"/> 
</figure>

<p>&hellip;okay, 1.98 <em>is</em> a good score, and without spending much time adding features to the model and <a href="https://github.com/Microsoft/LightGBM/blob/master/docs/Parameters-tuning.md">tuning parameters</a>! To be fair, my methodology would not necessarily result in the same score on the Kaggle dataset, but it confirms that the LightGBM model is in the top tier of models available for this problem and dataset context. And it didn&rsquo;t <a href="https://www.kaggle.com/smerity/sf-crime/fighting-crime-with-keras/output">require any neural networks</a> either!</p>
<p>There are areas for improvement in feature engineering which <a href="https://www.kaggle.com/c/sf-crime/kernels">other entries</a> in the Kaggle competition implemented, such as a <a href="https://en.wikipedia.org/wiki/Dummy_variable_%28statistics%29">dummy variable</a> indicating whether the offense occurred at an intersection and which SF police station was involved in the arrest. We could also encode features such as hour and DayOfWeek as categorical features (LightGBM conveniently allows this without requiring <a href="https://en.wikipedia.org/wiki/One-hot">one-hot encoding</a> the features) instead of numeric, but in my brief testing, it made the model <em>worse</em>, interestingly.</p>
<h2 id="analyzing-the-lightgbm-model">Analyzing the LightGBM Model</h2>
<p>Another perk of not using a neural network for statistical model building is the ability to learn more about the importance of features in a model, as opposed to it being a <a href="https://en.wikipedia.org/wiki/Black_box">black box</a>. In the case of gradient boosting, we can calculate the proportional contribution of each feature to the total <a href="https://en.wikipedia.org/wiki/Information_gain_in_decision_trees">information gain</a> of the model, which will help identify the most important features, and potentially unhelpful features.</p>
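<p>With the Python API, computing those proportions might look like this (a sketch; the post computes and plots them in R):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import pandas as pd

# per-feature contribution to total information gain, as a proportion
gain = pd.Series(model.feature_importance(importance_type='gain'),
                 index=features)
print((gain / gain.sum()).sort_values(ascending=False))
</code></pre></div>
<p>The resulting importances:</p>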
<figure>

    <img loading="lazy" srcset="/2017/02/predicting-arrests/imp_hu_9d9e5109c1b26d0f.webp 320w,/2017/02/predicting-arrests/imp_hu_727e31f0d7218a03.webp 768w,/2017/02/predicting-arrests/imp_hu_19e14c37ffdd5923.webp 1024w,/2017/02/predicting-arrests/imp.png 1200w" src="imp.png"/> 
</figure>

<p>Unsurprisingly, the location features are the most important, with the two location-based features accounting for 70% of the total Gain in the model. But no feature is completely insignificant, which is a good thing.</p>
<p>Back to the multilogloss of 1.98. What does that mean in the real world? What is the <em>accuracy</em> of the model? We run each of the 190,288 arrests in the test set against the model, which returns 39 probability values for each record: one for each possible category of arrest. The category with the highest probability becomes the <strong>predicted</strong> type of arrest.</p>
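<p>In code, the prediction step reduces to an argmax over each row of probabilities (a Python sketch):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">probs = model.predict(test_df[features])   # one row of 39 probabilities each
pred = probs.argmax(axis=1)                # highest-probability category wins
accuracy = (pred == test_df['label'].values).mean()
</code></pre></div>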
<figure>

    <img loading="lazy" srcset="/2017/02/predicting-arrests/predicted_results_hu_ef9e6b4cb94c80ab.webp 320w,/2017/02/predicting-arrests/predicted_results.png 676w" src="predicted_results.png"/> 
</figure>

<p>The accuracy of the model on the test set, which is the proportion of predictions where the predicted category value matches the <strong>actual</strong> category value, is <strong>39.7%</strong>, with a 95% confidence interval for the true accuracy between 39.5% and 39.9%. That seems low! However, there is a catch-all &ldquo;Other Offenses&rdquo; category for an arrest; if you predicted an &ldquo;Other Offenses&rdquo; label for all the test-set values, you would get an accuracy of <em>31.1%</em>, which serves as the No Information Rate (since it would be the highest-accuracy approach if there were no information at all). An 8.6 percentage point improvement is still an improvement, though; many industries would <em>love</em> an 8.6 percentage point increase in accuracy, but in this context it&rsquo;s obviously not enough to usher in a <a href="https://en.wikipedia.org/wiki/Minority_Report_%28film%29">Minority Report</a>/<a href="https://en.wikipedia.org/wiki/Person_of_Interest_%28TV_series%29">Person of Interest</a> future.</p>
<p>We can visualize the model&rsquo;s classifications on the test set using a <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a>; <code>caret</code> has a simple <code>confusionMatrix()</code> function, and ggplot2 has a <code>geom_tile()</code> to map out the relationships, even with 39 classes. We can also annotate the tiles where actual label = predicted label by drawing a <code>geom_point()</code> on top.</p>
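<p>A Python sketch of the underlying computation (the post&rsquo;s version uses <code>caret</code> and ggplot2, as described above):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(test_df['label'], pred)   # 39 x 39 count matrix
plt.imshow(cm, cmap='Blues')
plt.xlabel('predicted label')
plt.ylabel('actual label')
plt.savefig('confusion-matrix.png')
</code></pre></div>
<p>Putting it all together:</p>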
<figure>

    <img loading="lazy" srcset="/2017/02/predicting-arrests/confusionMatrix_hu_e3738e2da392e9b7.webp 320w,/2017/02/predicting-arrests/confusionMatrix_hu_b9fe669d84787a0b.webp 768w,/2017/02/predicting-arrests/confusionMatrix_hu_d0ad2b163b57b55c.webp 1024w,/2017/02/predicting-arrests/confusionMatrix.png 1500w" src="confusionMatrix.png"/> 
</figure>

<p>There is, indeed, a large amount of confusion. Many of the labels are mispredicted as Other Offenses. Specifically, the model frequently confuses the combinations of Assault, Drug/Narcotics, Larceny/Theft, and Warrants, suggesting that they also may be catch-alls.</p>
<p>In theory, the model&rsquo;s predicted probabilities for similar types of crime should also be similar, which may be causing these mispredictions. We can calculate the <a href="https://en.wikipedia.org/wiki/Pearson_correlation_coefficient">Pearson correlations</a> between the predicted probabilities, and use <a href="https://en.wikipedia.org/wiki/Hierarchical_clustering">hierarchical clustering</a> to <a href="http://www.sthda.com/english/wiki/ggplot2-quick-correlation-matrix-heatmap-r-software-and-data-visualization">arrange and plot the correlations</a> and their labels in a logical order. The majority of the correlations between labels are between 0 and +/- 0.5 (weak to moderate), but their arrangement tells a different story.</p>
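<p>A sketch of the correlation-plus-clustering step in Python (the post does this in R; the linkage method here is an assumption):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np
from scipy.cluster.hierarchy import leaves_list, linkage
from scipy.spatial.distance import squareform

corr = np.corrcoef(probs.T)      # Pearson correlations, 39 x 39
dist = 1 - corr                  # turn similarity into a distance
np.fill_diagonal(dist, 0)
order = leaves_list(linkage(squareform(dist, checks=False),
                            method='average'))
corr_ordered = corr[np.ix_(order, order)]
</code></pre></div>
<p>The reordered correlation matrix:</p>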
<figure>

    <img loading="lazy" srcset="/2017/02/predicting-arrests/correlationMatrix_hu_2d7aa28c0716951c.webp 320w,/2017/02/predicting-arrests/correlationMatrix_hu_293fd514eec51149.webp 768w,/2017/02/predicting-arrests/correlationMatrix_hu_6340e59358c00dc2.webp 1024w,/2017/02/predicting-arrests/correlationMatrix.png 1500w" src="correlationMatrix.png"/> 
</figure>

<p>From top to bottom, you can see that there is a grouping of more blue-collar, physical crimes types (Assault, Vandalism), then a grouping of less-physical, white-collar crime types (Bribery, Extortion), and then a smaller grouping of seedier crime types (Liquor Laws, Prostitution).</p>
<p>The visualization doesn&rsquo;t necessarily provide more information about the confusion matrix and the mispredictions, but <em>it looks cool</em>, which is enough.</p>
<h2 id="mapping-the-predicted-types-of-arrests">Mapping the Predicted Types of Arrests</h2>
<p>Kaggle competitions emphasize model creation, but don&rsquo;t discuss how to implement and execute models in practice. Since we can predict the type of crime given the location and date/time of an arrest, we can map the boundaries of the most likely type of offense: use <code>ggmap</code> to get a map of San Francisco, split the city into tens of thousands of points, and predict the most-likely type of arrest at each point for a given date/time.</p>
<p>Let&rsquo;s say we want to predict the types of crime in the future, on April 15th, 2017, at 8 PM. We construct a dataset of those points with the same date/time features used to generate the model originally, then run those fabricated points through the model to get new predicted labels. (Additionally, we need to remove &ldquo;Other Offenses&rdquo; predicted labels, since they cloud up the map.) Plotting each point as a <code>geom_tile</code> will interpolate regions around the city.</p>
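<p>Constructing that grid of fabricated points might look like this in Python (a sketch; the coordinate bounds are rough assumptions, and the post builds the grid in R):</p>
<div class="highlight"><pre tabindex="0" class="chroma"><code class="language-python" data-lang="python">import numpy as np
import pandas as pd

xs = np.linspace(-122.52, -122.35, 150)   # longitude bounds (approximate)
ys = np.linspace(37.70, 37.82, 150)       # latitude bounds (approximate)
grid = pd.DataFrame([(x, y) for x in xs for y in ys], columns=['X', 'Y'])
grid['hour'] = 20                          # 8 PM
grid['month'] = 4
grid['year'] = 2017
grid['DayOfWeek'] = 6                      # April 15, 2017 is a Saturday
grid['pred'] = model.predict(grid[features]).argmax(axis=1)
# drop rows predicted as the catch-all label before mapping, per the text
</code></pre></div>
<p>Putting it all together:</p>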
<figure>

    <img loading="lazy" srcset="/2017/02/predicting-arrests/crime-2017-04-15-20_hu_a7c72827828f3edb.webp 320w,/2017/02/predicting-arrests/crime-2017-04-15-20_hu_67886ddcfeeb3351.webp 768w,/2017/02/predicting-arrests/crime-2017-04-15-20.png 900w" src="crime-2017-04-15-20.png"/> 
</figure>

<p>Not too shabby. But that&rsquo;s not all; we can <em>animate</em> this map over a day by incrementing the hour, generating a map for each hour (while keeping the colors corresponding to the arrest type consistent), and then <a href="https://github.com/minimaxir/frames-to-gif-osx">stitching the maps together</a> into a GIF. Let&rsquo;s do March 14th, 2017 (<a href="https://en.wikipedia.org/wiki/Pi_Day">Pi Day</a> can be dangerous!) starting at 6 AM:</p>
<figure>

    <img loading="lazy" srcset="/2017/02/predicting-arrests/map_ani_hu_ba104cef91fd9f19.webp 320w,/2017/02/predicting-arrests/map_ani.gif 700w" src="map_ani.gif"/> 
</figure>

<p>Wow!</p>
<h2 id="conclusion">Conclusion</h2>
<p>I deliberately avoided using the term &ldquo;machine learning&rdquo; in the headline of this post because it has been overused to the point of clickbait. Indeed, neural networks/deep learning excel at processing higher-dimensional data such as text, image, and voice data, but in cases where dataset features are <a href="https://news.ycombinator.com/item?id=13563892">simple and known</a>, neural networks are not necessarily the most <em>pragmatic</em> option. CPU/RAM machine learning libraries like LightGBM are still worthwhile, despite the religious fervor for deep learning.</p>
<p>And there&rsquo;s still a lot of work that can be done with the SF Crime Incidents dataset. The model only predicts the type of crime given an arrest occurred; it does not predict <em>if</em> an arrest will occur at a given time and place, which would make a fun project for the future!</p>
<hr>
<p><em>You can view all the R and ggplot2 code used to visualize the San Francisco crime data in <a href="http://minimaxir.com/notebooks/predicting-arrests/">this R Notebook</a>. You can also view the images/data used for this post in <a href="https://github.com/minimaxir/sf-arrests-predict">this GitHub repository</a></em>.</p>
<p><em>You are free to use the data visualizations from this article however you wish, but it would be greatly appreciated if proper attribution is given to this article and/or myself!</em></p>
]]></content:encoded>
    </item>
  </channel>
</rss>
