diff --git a/djblets/datagrid/grids.py b/djblets/datagrid/grids.py
index 1eff36b95b909bdf72d1a93e30d768152b1a47f3..a3883958e46752d917934a121f99c3c156e5d55b 100644
--- a/djblets/datagrid/grids.py
+++ b/djblets/datagrid/grids.py
@@ -30,6 +30,7 @@ from django.conf import settings
 from django.contrib.auth.models import SiteProfileNotAvailable
 from django.core.exceptions import ObjectDoesNotExist
 from django.core.paginator import InvalidPage, QuerySetPaginator
+from django.db.models import ForeignKey
 from django.http import Http404, HttpResponse
 from django.shortcuts import render_to_response
 from django.template.context import RequestContext, Context
@@ -208,6 +209,38 @@ class Column(object):
 
         return s
 
+    def collect_objects(self, object_list):
+        """Iterates through the objects and builds a cache of data to display.
+
+        This optimizes the fetching of data in the grid by grabbing all the
+        IDs of related objects that will be queried for rendering, loading
+        them all at once, and populating the cache.
+        """
+        id_field = '%s_id' % self.field_name
+        ids = set()
+        model = None
+
+        for obj in object_list:
+            if not hasattr(obj, id_field):
+                # This isn't the field type you're looking for.
+                return
+
+            ids.add(getattr(obj, id_field))
+
+            if not model:
+                field = getattr(obj.__class__, self.field_name).field
+
+                try:
+                    model = field.rel.to
+                except AttributeError:
+                    # No idea what this is. Bail.
+                    return
+
+        if model:
+            for obj in model.objects.filter(pk__in=ids):
+                self.data_cache[obj.pk] = obj
+
+
     def render_cell(self, obj):
         """
         Renders the table cell containing column data.
@@ -639,6 +672,9 @@ class DataGrid(object):
             # and it will prevent one query per row.
             object_list = list(self.page.object_list)
 
+        for column in self.columns:
+            column.collect_objects(object_list)
+
         self.rows = [
             {
                 'object': obj,
