diff --git a/djblets/db/fields/relation_counter_field.py b/djblets/db/fields/relation_counter_field.py
index e33d18489b2e61d2610297bdc8bbb947c4cd83bc..f66167af1126c62035305a732353cbe16176a900 100644
--- a/djblets/db/fields/relation_counter_field.py
+++ b/djblets/db/fields/relation_counter_field.py
@@ -13,12 +13,12 @@ from django.utils import six
 from djblets.db.fields.counter_field import CounterField
 
 
-class InstanceState(object):
+class InstanceState(weakref.ref):
     """Tracks state for a RelationCounterField instance assocation.
 
     State instances are bound to the lifecycle of a model instance.
-    They keep track of the model instance (using a weak reference) and
-    all RelationCounterFields tied to the relation name provided.
+    They're a type of weak reference for model instances that contain
+    additional state needed for the tracking and update process.
 
     These are used for looking up the proper instance and
     RelationCounterFields on the other end of a reverse relation, given
@@ -29,7 +29,7 @@ class InstanceState(object):
 
     Instance states can either represent saved instances or unsaved
     instances. Unsaved instance states represent instances that haven't yet
-    been saved to the database (with a primary key of None). While saved
+    been saved to the database (with a primary key of ``None``). While saved
     instance states exist per-instance/relation name, there's only one
     unsaved instance state per instance.
 
@@ -39,35 +39,35 @@ class InstanceState(object):
     discarded.
     """
 
-    def setup(self, model_instance):
+    def __init__(self, model_instance):
         """Set up the state.
 
         Args:
             model_instance (django.db.models.Model):
                 The model instance that this state tracks.
         """
-        self.fields = set()
+        super(InstanceState, self).__init__(model_instance)
+
         self.to_clear = set()
-        self.dispatch_uid = '%s.%s:%s' % (self.__class__.__module__,
-                                          self.__class__.__name__,
-                                          id(model_instance))
-        self.model_instance_ref = weakref.ref(model_instance,
-                                              self._on_instance_destroyed)
-        self.model_cls = model_instance.__class__
-        self.model_instance_id = id(model_instance)
-
-        if model_instance.pk is None:
-            post_save.connect(self._on_instance_first_save,
-                              sender=self.model_cls,
-                              dispatch_uid=self.dispatch_uid)
-        else:
-            pre_delete.connect(self._on_instance_pre_delete,
-                               sender=self.model_cls,
-                               dispatch_uid=self.dispatch_uid)
+        self.field_names = set()
+        self._model_cls = type(model_instance)
 
     @property
     def model_instance(self):
-        return self.model_instance_ref()
+        """The model instance being tracked.
+
+        This will be ``None`` if the instance has been destroyed.
+        """
+        return self()
+
+    def track_field(self, field):
+        """Track information on a field referencing this state.
+
+        Args:
+            field (django.db.models.Field):
+                The field to track.
+        """
+        self.field_names.add(field.attname)
 
     def __repr__(self):
         """Return a string representation of the instance state.
@@ -85,95 +85,6 @@ class InstanceState(object):
         else:
             return '<InstanceState for %r (destroyed)>' % self.model_cls
 
-    def _on_instance_first_save(self, instance, created=False, **kwargs):
-        """Handler for the first save on a newly created instance.
-
-        This will reset information on this instance, removing this
-        existing state, and will then add new instance states for each
-        field relation.
-
-        Args:
-            instance (django.db.models.Model):
-                The model instance being saved.
-
-            created (bool):
-                Whether the object was created. This must always be
-                true for this handler.
-
-            **kwargs (dict):
-                Extra keyword arguments passed to the handler.
-        """
-        model_instance = self.model_instance
-
-        if model_instance is None or instance is not model_instance:
-            return
-
-        assert created
-        assert instance.pk is not None
-
-        instance_cls = instance.__class__
-
-        # Stop listening immediately for any new signals here.
-        # The Signal stuff deals with thread locks, so we shouldn't
-        # have to worry about reaching any of this twice.
-        post_save.disconnect(sender=instance_cls,
-                             dispatch_uid=self.dispatch_uid)
-
-        # This is a new row in the database (that is, the model instance
-        # has been saved for the very first time), we need to flush any
-        # existing state. This will ensure the unsaved version of this
-        # state does not remain.
-        RelationCounterField._cleanup_state(instance_cls=instance_cls,
-                                            instance_pk=instance.pk,
-                                            instance_id=id(instance))
-
-        # Now we can register each RelationCounterField on here.
-        for field in instance_cls._meta.local_fields:
-            if isinstance(field, RelationCounterField):
-                RelationCounterField._store_state(instance, field)
-
-    def _on_instance_destroyed(self, *args):
-        """Handler for when the instance is destroyed.
-
-        This will remove all state related to the instance. That will
-        result in the state object being destroyed.
-
-        Args:
-            *args (tuple, unused):
-                Arguments passed to the callback.
-        """
-        try:
-            RelationCounterField._cleanup_state(
-                instance_cls=self.model_cls,
-                instance_pk=None,
-                instance_id=self.model_instance_id)
-        except AttributeError:
-            # Ignore any attribute errors when this fails. It is most
-            # likely occurring while a thread/process is shutting down,
-            # and some state no longer exists. We've seen this manifest
-            # as two separate AttributeErrors so far.
-            pass
-
-    def _on_instance_pre_delete(self, instance, **kwargs):
-        """Handler for when an instance is about to be deleted.
-
-        This will reset the state of the instance, unregistering it from
-        lists, and removing any pending signal connections.
-
-        Args:
-            instance (django.db.models.Model):
-                The instance being deleted.
-        """
-        model_instance = self.model_instance
-
-        if model_instance is not None and instance is model_instance:
-            RelationCounterField._cleanup_state(
-                instance_cls=instance.__class__,
-                instance_pk=instance.pk,
-                instance_id=id(instance))
-
-        pre_delete.disconnect(sender=self.model_cls,
-                              dispatch_uid=self.dispatch_uid)
 
 class RelationTracker(object):
     """Tracks relations and updates state for all affected CounterFields.
@@ -361,10 +272,7 @@ class RelationTracker(object):
         model_instance = main_state.model_instance
 
         if model_instance is not None:
-            yield (
-                model_instance,
-                [field.attname for field in main_state.fields]
-            )
+            yield model_instance, main_state.field_names
 
             if len(states) > 1:
                 self._sync_fields_from_main_state(main_state, states[1:])
@@ -544,9 +452,9 @@ class RelationTracker(object):
                 other_instance = other_state.model_instance
 
                 if other_instance is not None:
-                    for field in other_state.fields:
-                        setattr(other_instance, field.attname,
-                                getattr(main_instance, field.attname))
+                    for field_name in other_state.field_names:
+                        setattr(other_instance, field_name,
+                                getattr(main_instance, field_name))
         else:
             # The instance fell out of scope. We'll have to just reload
             # all the other instances. This should be rare.
@@ -657,6 +565,27 @@ class RelationCounterField(CounterField):
     # end up blocking.
     _state_lock = threading.RLock()
 
+    # Flag for determining if global signal handlers need to be set up.
+    _signals_setup = False
+
+    @classmethod
+    def has_tracked_states(cls):
+        """Return whether there are currently any states being tracked.
+
+        This will begin by cleaning up any expired states whose instances
+        have been destroyed, if there are any. Then it will check if there
+        are any remaining states still being tracked and return a result.
+
+        Returns:
+            bool:
+            ``True`` if there are any states still being tracked.
+            ``False`` if not.
+        """
+        cls._cleanup_state()
+
+        return (bool(cls._saved_instance_states) or
+                bool(cls._unsaved_instance_states))
+
     @classmethod
     def _cleanup_state(cls, instance_cls=None, instance_pk=None,
                        instance_id=None):
@@ -758,36 +687,41 @@ class RelationCounterField(CounterField):
         that points to the :py:class:`InstanceState`, keeping the state's
         reference alive as long as the instance is alive.
         """
+        cls._cleanup_state()
+
         with cls._state_lock:
             if instance.pk is None:
                 states = cls._unsaved_instance_states
             else:
-                states = cls._saved_instance_states
-                key = (instance.__class__, instance.pk, field._rel_field_name)
-
-                try:
-                    states = cls._saved_instance_states[key]
-                except KeyError:
-                    states = weakref.WeakValueDictionary()
-                    cls._saved_instance_states[key] = states
+                main_key = (type(instance), instance.pk, field._rel_field_name)
+                states = cls._saved_instance_states.setdefault(main_key, {})
 
             key = id(instance)
 
             try:
                 state = states[key]
-                state_is_new = False
             except KeyError:
-                state = InstanceState()
-                state_is_new = True
+                state = InstanceState(instance)
                 states[key] = state
 
-            if instance.pk is not None:
-                setattr(instance, '_%s_state' % field.attname, state)
+            state.track_field(field)
 
-        if state_is_new:
-            state.setup(instance)
+            # Mark that this instance tracks RelationCounterField states,
+            # so our signal handlers have something they can easily look for.
+            instance._tracks_relcounterfield_states = True
 
-        state.fields.add(field)
+            if instance.pk is not None:
+                # Attach the state to the field. This, along with being stored
+                # in a WeakValueDictionary above, ensures that, so long as the
+                # field remains in memory, the state will remain as well, and
+                # that the state will be removed once the field instance
+                # disappears (which will happen when the model instance
+                # disappears).
+                #
+                # There should never be more than one state attached for a
+                # field, since store_state() is only called once per model
+                # instance per field, and field names are unique.
+                setattr(instance, '_%s_instance_state' % field.attname, state)
 
     @classmethod
     def _get_saved_states(cls, model_cls, instance_pk, rel_field_name):
@@ -829,6 +763,103 @@ class RelationCounterField(CounterField):
 
         return None
 
+    @classmethod
+    def _on_instance_first_save(cls, instance, created, **kwargs):
+        """Handler for the first save on a newly created instance.
+
+        This will reset information on this instance, removing this
+        existing state, and will then add new instance states for each
+        field relation.
+
+        Args:
+            instance (django.db.models.Model):
+                The model instance being saved.
+
+            created (bool):
+                Whether the object was created. This must always be
+                true for this handler.
+
+            **kwargs (dict):
+                Extra keyword arguments passed to the handler.
+
+        Returns:
+            bool:
+            ``True`` if this instance was handled. ``False`` if it was ignored.
+        """
+        assert instance is not None
+        assert instance.pk is not None
+
+        if (not created or
+            not getattr(instance, '_tracks_relcounterfield_states', False)):
+            # This isn't an instance we're tracking. Ignore it.
+            return False
+
+        instance_id = id(instance)
+
+        try:
+            state = cls._unsaved_instance_states[instance_id]
+        except KeyError:
+            # This isn't a tracked unsaved instance. We can skip it.
+            return
+
+        model_instance = state.model_instance
+
+        if model_instance is None:
+            # The references dropped. We're no longer working with this state.
+            return False
+
+        assert instance is model_instance
+
+        with cls._state_lock:
+            # Remove the old state information from the dictionary. We could
+            # do this with a .pop(key, None) above, but we don't want to lock
+            # unnecessarily.
+            #
+            # Even though we just fetched it, we can't assume that the data
+            # is still there, since another thread could have done something
+            # with it, so we can't do a plain del().
+            cls._unsaved_instance_states.pop(instance_id, None)
+
+        # Now we can register each RelationCounterField on here.
+        for field in type(instance)._meta.local_fields:
+            if isinstance(field, cls):
+                cls._store_state(instance, field)
+
+        return True
+
+    @classmethod
+    def _on_instance_pre_delete(cls, instance, **kwargs):
+        """Handler for when an instance is about to be deleted.
+
+        This will reset the state of the instance, unregistering it from
+        lists, and removing any pending signal connections.
+
+        Args:
+            instance (django.db.models.Model):
+                The instance being deleted.
+
+            **kwargs (dict):
+                Extra keyword arguments passed to the handler.
+
+        Returns:
+            bool:
+            ``True`` if this instance was handled. ``False`` if it was ignored.
+        """
+        assert instance is not None
+
+        if not getattr(instance, '_tracks_relcounterfield_states', False):
+            # This isn't an instance we're tracking. Ignore it.
+            return False
+
+        instance_id = id(instance)
+        assert instance_id not in cls._unsaved_instance_states
+
+        cls._cleanup_state(instance_cls=type(instance),
+                           instance_pk=instance.pk,
+                           instance_id=instance_id)
+
+        return True
+
     def __init__(self, rel_field_name=None, *args, **kwargs):
         def _initializer(model_instance):
             if model_instance.pk:
@@ -856,17 +887,32 @@ class RelationCounterField(CounterField):
         """
         super(RelationCounterField, self)._do_post_init(instance)
 
-        RelationCounterField._store_state(instance, self)
+        cls = type(self)
+
+        # The first time we reach this, we want to let the class itself begin
+        # listening for signals for instance saves and deletes.
+        #
+        # Note that, in theory, this could end up being called twice, but
+        # that's ultimately okay, since we'll be using the same dispatch UID
+        # on all signal connections, and the signals will ensure in a
+        # thread-safe way that there's only one callback per dispatch UID.
+        if not cls._signals_setup:
+            dispatch_uid = '%s.%s' % (cls.__module__, cls.__name__)
+            post_save.connect(cls._on_instance_first_save,
+                              dispatch_uid=dispatch_uid)
+            pre_delete.connect(cls._on_instance_pre_delete,
+                               dispatch_uid=dispatch_uid)
+            cls._signals_setup = True
+
+        cls._store_state(instance, self)
 
         if not self._relation_tracker:
-            instance_cls = instance.__class__
+            instance_cls = type(instance)
             key = (instance_cls, self._rel_field_name)
 
             try:
-                self._relation_tracker = \
-                    RelationCounterField._relation_trackers[key]
+                self._relation_tracker = cls._relation_trackers[key]
             except KeyError:
-                self._relation_tracker = \
-                    RelationTracker(instance_cls, self._rel_field_name)
-                RelationCounterField._relation_trackers[key] = \
-                    self._relation_tracker
+                self._relation_tracker = RelationTracker(instance_cls,
+                                                         self._rel_field_name)
+                cls._relation_trackers[key] = self._relation_tracker
diff --git a/djblets/db/tests/test_relation_counter_field.py b/djblets/db/tests/test_relation_counter_field.py
index 3db4a9c027c78ae6c0603ff86a7833c148aa9bd9..97e25a45716ab42080a0458e2f7dec081a399963 100644
--- a/djblets/db/tests/test_relation_counter_field.py
+++ b/djblets/db/tests/test_relation_counter_field.py
@@ -1,9 +1,12 @@
 from __future__ import unicode_literals
 
+import gc
+
 import django
 import nose
 from django.db import models, transaction
 from django.db.models.signals import post_save, pre_delete
+from kgb import SpyAgency
 
 from djblets.db.fields import RelationCounterField
 from djblets.testing.testcases import TestCase, TestModelsLoaderMixin
@@ -55,7 +58,7 @@ class BadKeyRefModel(models.Model):
         return False
 
 
-class RelationCounterFieldTests(TestModelsLoaderMixin, TestCase):
+class RelationCounterFieldTests(SpyAgency, TestModelsLoaderMixin, TestCase):
     """Tests for djblets.db.fields.RelationCounterField."""
     tests_app = 'djblets.db.tests'
 
@@ -97,16 +100,16 @@ class RelationCounterFieldTests(TestModelsLoaderMixin, TestCase):
 
         # Make sure the state is clear due to dropped references before
         # each run.
-        self.assertFalse(RelationCounterField._saved_instance_states)
-        self.assertFalse(RelationCounterField._unsaved_instance_states)
+        gc.collect()
+        self.assertFalse(RelationCounterField.has_tracked_states())
 
     def tearDown(self):
         super(RelationCounterFieldTests, self).tearDown()
 
         # Make sure the state is clear due to dropped references after
         # each run.
-        self.assertFalse(RelationCounterField._saved_instance_states)
-        self.assertFalse(RelationCounterField._unsaved_instance_states)
+        gc.collect()
+        self.assertFalse(RelationCounterField.has_tracked_states())
 
     #
     # Instance tracking tests
@@ -176,6 +179,8 @@ class RelationCounterFieldTests(TestModelsLoaderMixin, TestCase):
 
         # Adding a reference again should result in these counters being
         # increased, rather than the old ones.
+        assert reffed.pk
+        assert new_model.pk
         new_model.m2m.add(reffed)
         self.assertEqual(new_model.counter, 1)
         self.assertEqual(new_model.counter_2, 1)
@@ -186,33 +191,31 @@ class RelationCounterFieldTests(TestModelsLoaderMixin, TestCase):
         self.assertEqual(model.counter, 1)
         self.assertEqual(model.counter_2, 1)
 
-    def test_signals_with_create(self):
-        """Testing RelationCounterField signal management with newly-created,
-        unsaved instance
+    def test_save_calls_on_instance_first_save(self):
+        """Testing RelationCounterField._on_instance_first_save called on
+        first model.save()
         """
-        model = M2MRefModel()
-        self.assertIsNone(model.pk)
-
-        self.assertEqual(len(pre_delete._live_receivers(M2MRefModel)), 0)
-        self.assertEqual(len(post_save._live_receivers(M2MRefModel)), 1)
+        self.spy_on(RelationCounterField._on_instance_first_save)
+        save_func = RelationCounterField._on_instance_first_save
 
-    def test_signals_with_save(self):
-        """Testing RelationCounterField signal management after first instance
-        save
-        """
         model = M2MRefModel()
         model.save()
 
-        self.assertEqual(len(pre_delete._live_receivers(M2MRefModel)), 1)
-        self.assertEqual(len(post_save._live_receivers(M2MRefModel)), 0)
+        self.assertTrue(save_func.last_called_with(model, created=True))
+        self.assertTrue(save_func.last_returned(True))
+
+    def test_delete_calls_on_instance_pre_delete(self):
+        """Testing RelationCounterField._on_instance_pre_delete called on
+        model.delete()
+        """
+        self.spy_on(RelationCounterField._on_instance_pre_delete)
+        delete_func = RelationCounterField._on_instance_pre_delete
 
-    def test_signals_with_delete(self):
-        """Testing RelationCounterField signal management after delete"""
         model = M2MRefModel.objects.create()
         model.delete()
 
-        self.assertEqual(len(pre_delete._live_receivers(M2MRefModel)), 0)
-        self.assertEqual(len(post_save._live_receivers(M2MRefModel)), 0)
+        self.assertTrue(delete_func.last_called_with(instance=model))
+        self.assertTrue(delete_func.last_returned(True))
 
     def test_unsaved_and_other_double_save(self):
         """Testing RelationCounterField with an unsaved object and a double
@@ -224,30 +227,27 @@ class RelationCounterFieldTests(TestModelsLoaderMixin, TestCase):
         # signal connection from the first stuck around and saw that
         # updated=False, which it expected would be True. However, it didn't
         # check first if it was matching the expected instance.
+        self.spy_on(RelationCounterField._on_instance_first_save)
+        save_func = RelationCounterField._on_instance_first_save
+
         model1 = M2MRefModel()
         model2 = M2MRefModel()
-        self.assertEqual(model1.pk, None)
-        self.assertEqual(model2.pk, None)
-        self.assertEqual(len(post_save._live_receivers(M2MRefModel)), 2)
+        self.assertIsNone(model1.pk)
+        self.assertIsNone(model2.pk)
+
+        self.assertEqual(len(save_func.calls), 0)
 
         # Perform the first save, which will do update=True.
         model2.save()
-        self.assertEqual(len(post_save._live_receivers(M2MRefModel)), 1)
+        self.assertEqual(len(save_func.calls), 1)
+        self.assertTrue(save_func.last_called_with(model2, created=True))
+        self.assertTrue(save_func.last_returned(True))
 
         # Perform the second save, which will do update=False.
         model2.save()
-        self.assertEqual(len(post_save._live_receivers(M2MRefModel)), 1)
-
-    def test_disconnect_signal_on_destroy(self):
-        """Testing RelationCounterField disconnects signals for an object when
-        it falls out of scope
-        """
-        model = M2MRefModel()
-        self.assertEqual(model.pk, None)
-        self.assertEqual(len(post_save._live_receivers(M2MRefModel)), 1)
-
-        model = None
-        self.assertEqual(len(post_save._live_receivers(M2MRefModel)), 0)
+        self.assertEqual(len(save_func.calls), 2)
+        self.assertTrue(save_func.last_called_with(model2, created=False))
+        self.assertTrue(save_func.last_returned(False))
 
 
     #
