diff --git a/include/nlohmann/detail/input/json_sax.hpp b/include/nlohmann/detail/input/json_sax.hpp
index 38a0a710..d0161cd4 100644
--- a/include/nlohmann/detail/input/json_sax.hpp
+++ b/include/nlohmann/detail/input/json_sax.hpp
@@ -176,7 +176,8 @@ class json_sax_dom_parser : public json_sax<BasicJsonType>
 
     bool key(std::string&& val) override
     {
-        last_key = val;
+        // add null at given key and store the reference for later
+        object_element = &(ref_stack.back()->m_value.object->operator[](val));
         return true;
     }
 
@@ -219,8 +220,8 @@ class json_sax_dom_parser : public json_sax<BasicJsonType>
     BasicJsonType root;
     /// stack to model hierarchy of values
     std::vector<BasicJsonType*> ref_stack;
-    /// helper variable for object keys
-    std::string last_key;
+    /// helper to hold the reference for the next object element
+    BasicJsonType* object_element = nullptr;
 
     /*!
     @invariant If the ref stack is empty, then the passed value will be the new
@@ -247,8 +248,9 @@ class json_sax_dom_parser : public json_sax<BasicJsonType>
             }
             else
             {
-                BasicJsonType& r = ref_stack.back()->m_value.object->operator[](last_key) = BasicJsonType(std::forward<Value>(v));
-                return &r;
+                assert(object_element);
+                *object_element = BasicJsonType(std::forward<Value>(v));
+                return object_element;
             }
         }
     }
diff --git a/single_include/nlohmann/json.hpp b/single_include/nlohmann/json.hpp
index 009f1109..3b1f7e43 100644
--- a/single_include/nlohmann/json.hpp
+++ b/single_include/nlohmann/json.hpp
@@ -3310,7 +3310,8 @@ class json_sax_dom_parser : public json_sax<BasicJsonType>
 
     bool key(std::string&& val) override
     {
-        last_key = val;
+        // add null at given key and store the reference for later
+        object_element = &(ref_stack.back()->m_value.object->operator[](val));
         return true;
     }
 
@@ -3353,8 +3354,8 @@ class json_sax_dom_parser : public json_sax<BasicJsonType>
     BasicJsonType root;
     /// stack to model hierarchy of values
     std::vector<BasicJsonType*> ref_stack;
-    /// helper variable for object keys
-    std::string last_key;
+    /// helper to hold the reference for the next object element
+    BasicJsonType* object_element = nullptr;
 
     /*!
     @invariant If the ref stack is empty, then the passed value will be the new
@@ -3381,8 +3382,9 @@ class json_sax_dom_parser : public json_sax<BasicJsonType>
             }
             else
             {
-                BasicJsonType& r = ref_stack.back()->m_value.object->operator[](last_key) = BasicJsonType(std::forward<Value>(v));
-                return &r;
+                assert(object_element);
+                *object_element = BasicJsonType(std::forward<Value>(v));
+                return object_element;
             }
         }
     }