Django REST Framework: Dynamically select fields to return or exclude, with nested data.
Why?
GraphQL and REST are two schemes for sharing data across the web that are in hot contest today. One of the most commonly cited benefits of GraphQL is dealing with over-fetching. You can request only the fields you need, so your API call doesn’t need to return a chunky JSON object containing lots of information you don’t care about.
From a REST standpoint, sure, you could add endpoints that return the specific data for your use case. But that also means more code, more tests and more confusion. So is there a cleaner way to achieve this with Django REST Framework (DRF)?
The authors of DRF have thought about this. However, a key feature missing from their approach is support for nested data. What if you had a User object with a foreign key to Company, but only wanted the name of that company? What if you wanted to exclude certain fields, as opposed to including them?
What this solution offers
This solution offers a subclass of ModelSerializer that can be inherited in place of it.
This serializer has additional functionality for GET requests, allowing you to specify which fields to include or exclude.
If your serializer inherits from this class, you may add “only_fields” or “exclude_fields” to the query parameters of your GET request to reduce the amount of data returned in the call.
This also works with nested foreign keys, for example:
?only_fields=name,age&company__only_fields=id,name
?only_fields=company,name&company__exclude_fields=name
?exclude_fields=name&company__only_fields=id
?company__exclude_fields=name
Notice that for the company foreign key, we write the parameter much like a Django ORM join, using a double underscore. The query params we are interested in always end with only_fields or exclude_fields.
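As a small illustration of how a client might assemble such a query string, here is a sketch that uses only the Python standard library. The /api/users/ path is a hypothetical endpoint, not something defined in this article.

```python
from urllib.parse import urlencode

# Hypothetical endpoint; substitute the URL of your own user list view.
BASE_URL = "/api/users/"

params = {
    "only_fields": "name,age",          # fields to keep on the root (User)
    "company__only_fields": "id,name",  # fields to keep on the nested company
}

# safe="," keeps the commas readable instead of percent-encoding them as %2C.
query_string = urlencode(params, safe=",")
url = f"{BASE_URL}?{query_string}"
print(url)
```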
How it works
This approach builds on the suggestion by the authors of DRF. Some of the jargon used in the code is highlighted below:
- root: The serializer we are directly querying. In our example, this would be User.
- children: Any foreign keys or reverse foreign keys on User, e.g. company.
- dynamic_params: The query parameters in your GET request ending with exclude_fields or only_fields.
To achieve this, we create an abstract ModelSerializer subclass that other serializers can inherit from in place of ModelSerializer. We can call this class DynamicModelSerializer.
The first place we need to look is the __init__ method of our new DynamicModelSerializer.
from rest_framework import serializers


class DynamicModelSerializer(serializers.ModelSerializer):
    def __init__(self, *args, **kwargs):
        request = kwargs.get("context", {}).get("request")
        super().__init__(*args, **kwargs)
        # A request in the context marks this as the root serializer.
        is_root = bool(request)
        if is_root:
            if request.method != "GET":
                return
            dynamic_params = self.get_dynamic_params_for_root(request)
            self._context.update({"dynamic_params": dynamic_params})
First, we need to know whether the DynamicModelSerializer code we are running belongs to the root (User) serializer. One way to tell is to check whether request exists in the context.
Then we have a two-step approach:
- Get the dynamic_params from the request.
- Set it in the context of the serializer.
Side note: Here we use _context as opposed to context, because the context property has built-in functionality that keeps walking up to the parent serializer while one exists, and then returns that serializer's _context. This is not what we want, because we want each serializer to have a different set of dynamic_params (fields to exclude or include).
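To make the difference concrete, here is a simplified plain-Python model of that lookup behavior. Node is a hypothetical stand-in for a serializer, not DRF's actual class. It only mimics the idea that context defers to the topmost parent while _context is stored per instance.

```python
class Node:
    """Hypothetical stand-in for a serializer, for illustration only."""

    def __init__(self, parent=None):
        self.parent = parent
        self._context = {}  # stored separately on every instance

    @property
    def root(self):
        # Walk up the parent chain until there is no parent left.
        node = self
        while node.parent is not None:
            node = node.parent
        return node

    @property
    def context(self):
        # Always resolves to the root's dictionary, whoever asks.
        return self.root._context


user = Node()
company = Node(parent=user)

user._context["dynamic_params"] = {"only_fields": "name,company"}
company._context["dynamic_params"] = {"exclude_fields": "name"}

print(company.context["dynamic_params"])   # the root's params
print(company._context["dynamic_params"])  # the child's own params
```

Reading through context from the child would hand back the root's params, which is why the per-instance _context is the one that can hold a different set for each level.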
Once we’ve set these up, what's left to do is another simple two-step process:
- Use the dynamic_params to pop the fields we don’t want to return.
- Set the dynamic_params of the child serializers involved in this request.
This can be achieved in the to_representation method of the DynamicModelSerializer:
def to_representation(self, *args, **kwargs):
    # Work on a copy, because remove_unwanted_fields pops keys from it.
    if dynamic_params := self.get_dynamic_params().copy():
        self.remove_unwanted_fields(dynamic_params)
        self.set_dynamic_params_for_children(dynamic_params)
    return super().to_representation(*args, **kwargs)
This is all the logic that's needed for this process. When the child serializer runs, it will also go through this to_representation method. It uses the dynamic_params set by its parent to remove unwanted fields, and then sets the dynamic_params for its own children. Easy peasy.
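The same flow can be modeled without DRF at all. The sketch below is a plain-dictionary analogue, not the serializer itself: prune is a hypothetical helper that removes unwanted keys at the current level and then hands the remaining params down to any nested dictionaries. The sample record is made up for illustration.

```python
def prune(record, dynamic_params):
    """Return a copy of `record` filtered by only_fields / exclude_fields."""
    params = dict(dynamic_params)  # work on a copy, as to_representation does

    # Step 1: drop the keys we don't want at this level.
    only = params.pop("only_fields", None)
    exclude = params.pop("exclude_fields", None)
    if only:
        keep = set(only.split(","))
        record = {k: v for k, v in record.items() if k in keep}
    if exclude:
        drop = set(exclude.split(","))
        record = {k: v for k, v in record.items() if k not in drop}

    # Step 2: pass "child__rest" params down to the matching child as "rest".
    child_params = {}
    for param, fields in params.items():
        child, _, rest = param.partition("__")
        if rest and child in record:
            child_params.setdefault(child, {})[rest] = fields

    return {
        key: prune(value, child_params.get(key, {})) if isinstance(value, dict) else value
        for key, value in record.items()
    }


# Made-up sample data, purely for illustration.
user = {"id": 1, "name": "Ada", "age": 36, "company": {"id": 7, "name": "Acme"}}

print(prune(user, {"only_fields": "company,name", "company__exclude_fields": "name"}))
print(prune(user, {"company__exclude_fields": "name"}))
```

Each level only ever looks at its own only_fields and exclude_fields, which mirrors how every serializer in the chain handles its own dynamic_params.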
The nitty gritty
Don’t worry, it's not too nitty or too gritty. But we’ve touched on various functions that are all black boxes in the section above, so let's go over some of what the functions are doing. We highlighted above that the process just has two two-step approaches. Let's have a refresher on those 4 methods.
All of these methods are contained within the DynamicModelSerializer class.
Only for the root serializer:
self.get_dynamic_params_for_root(request)
self._context.update({"dynamic_params": dynamic_params})
For the root serializer, and all its descendants:
self.remove_unwanted_fields(dynamic_params)
self.set_dynamic_params_for_children(dynamic_params)
Of the 4 core steps above, step 2 just involves updating a dictionary, so we can leave that out.
get_dynamic_params_for_root(request)
@staticmethod
def is_param_dynamic(p):
    return p.endswith("only_fields") or p.endswith("exclude_fields")

def get_dynamic_params_for_root(self, request):
    query_params = request.query_params.items()
    # Keep only the query params that control field selection.
    return {k: v for k, v in query_params if self.is_param_dynamic(k)}
The code above is fairly self-explanatory: if a query param ends with only_fields or exclude_fields, add it to a dictionary and return it.
We use endswith because we know that all dynamic params end with only_fields or exclude_fields, be it company__only_fields=name or simply exclude_fields=name.
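To see the filter in isolation, the sketch below applies the same endswith test to a raw query string, using the standard library rather than a DRF request. The page parameter is an assumed extra, included only to show that unrelated params are ignored.

```python
from urllib.parse import parse_qsl

DYNAMIC_SUFFIXES = ("only_fields", "exclude_fields")


def is_param_dynamic(name):
    # str.endswith accepts a tuple, so both suffixes are checked in one call.
    return name.endswith(DYNAMIC_SUFFIXES)


raw_query = "only_fields=name,age&company__only_fields=id,name&page=2"
dynamic_params = {k: v for k, v in parse_qsl(raw_query) if is_param_dynamic(k)}
print(dynamic_params)
```

One thing to keep in mind: a suffix test also matches any unrelated parameter whose name happens to end the same way. If that matters for your API, comparing the last __-separated segment exactly would be a stricter alternative.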
remove_unwanted_fields(dynamic_params)
def only_keep_fields(self, fields_to_keep):
    fields_to_keep = set(fields_to_keep.split(","))
    all_fields = set(self.fields.keys())
    for field in all_fields - fields_to_keep:
        self.fields.pop(field, None)

def exclude_fields(self, fields_to_exclude):
    fields_to_exclude = fields_to_exclude.split(",")
    for field in fields_to_exclude:
        self.fields.pop(field, None)

def remove_unwanted_fields(self, dynamic_params):
    # Popping the keys leaves only the child params (e.g. company__...) behind.
    if fields_to_keep := dynamic_params.pop("only_fields", None):
        self.only_keep_fields(fields_to_keep)
    if fields_to_exclude := dynamic_params.pop("exclude_fields", None):
        self.exclude_fields(fields_to_exclude)
This code reflects the core of what is done in DRF’s basic implementation of this. Find out the fields to keep or exclude from the dynamic params, and pop the fields you don’t want.
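As a quick illustration of the two operations, the sketch below applies them to a plain dictionary standing in for self.fields. The field names and placeholder values are made up for the example.

```python
def only_keep_fields(fields, fields_to_keep):
    keep = set(fields_to_keep.split(","))
    for name in set(fields) - keep:
        fields.pop(name, None)


def exclude_fields(fields, fields_to_exclude):
    for name in fields_to_exclude.split(","):
        # pop with a default, so unknown names are ignored silently
        fields.pop(name, None)


# A plain dict standing in for self.fields; the values are placeholders.
fields = {
    "id": "IntegerField",
    "name": "CharField",
    "age": "IntegerField",
    "company": "CompanySerializer",
}

only_keep_fields(fields, "name,age,company")
print(list(fields))  # "id" has been removed

exclude_fields(fields, "age,nickname")  # "nickname" does not exist
print(list(fields))  # "age" has been removed as well
```

If both params are supplied, only_fields is applied first and exclude_fields then removes from whatever is left.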
set_dynamic_params_for_children(dynamic_params)
def get_or_create_dynamic_params(self, child):
    if "dynamic_params" not in self.fields[child]._context:
        self.fields[child]._context.update({"dynamic_params": {}})
    return self.fields[child]._context["dynamic_params"]

@staticmethod
def split_param(dynamic_param):
    crumbs = dynamic_param.split("__")
    return crumbs[0], "__".join(crumbs[1:]) if len(crumbs) > 1 else None

def set_dynamic_params_for_children(self, dynamic_params):
    for param, fields in dynamic_params.items():
        child, child_dynamic_param = self.split_param(param)
        if child in set(self.fields.keys()):
            # Named child_params so it does not shadow the dict being iterated.
            child_params = self.get_or_create_dynamic_params(child)
            child_params.update({child_dynamic_param: fields})
What this function does is take the keys of the dynamic params and split them on __, the join operator. If the left-hand side of the split is the name of a foreign key field in the current serializer, it adds a key to the _context of that field (the _context of the child serializer) containing the right-hand side of the split, along with the same value.
Note that self in the above context is an instance of UserSerializer.
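To trace how a deeper key travels down, the sketch below applies the same split repeatedly. The address level is a hypothetical extra relation, added only to show more than one level of nesting.

```python
def split_param(dynamic_param):
    crumbs = dynamic_param.split("__")
    return crumbs[0], "__".join(crumbs[1:]) if len(crumbs) > 1 else None


param = "company__address__only_fields"
hops = []
while param is not None:
    child, param = split_param(param)
    hops.append((child, param))

for child, remainder in hops:
    print(child, "->", remainder)
```

In the serializer itself the final hop is not a split. By the time the key has been reduced to plain only_fields, remove_unwanted_fields pops it before the children are visited.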
Full code example
from rest_framework import serializers


class DynamicModelSerializer(serializers.ModelSerializer):
    """
    For use with GET requests, to specify which fields to include or exclude.
    Mimics some GraphQL functionality.

    Usage: Inherit your ModelSerializer from this class. Add "only_fields" or
    "exclude_fields" to the query parameters of your GET request.

    This also works with nested foreign keys, for example:
    ?only_fields=name,age&company__only_fields=id,name

    Some more examples:
    ?only_fields=company,name&company__exclude_fields=name
    ?exclude_fields=name&company__only_fields=id
    ?company__exclude_fields=name

    Note: the foreign key serializer must also inherit from this class.
    """

    def only_keep_fields(self, fields_to_keep):
        fields_to_keep = set(fields_to_keep.split(","))
        all_fields = set(self.fields.keys())
        for field in all_fields - fields_to_keep:
            self.fields.pop(field, None)

    def exclude_fields(self, fields_to_exclude):
        fields_to_exclude = fields_to_exclude.split(",")
        for field in fields_to_exclude:
            self.fields.pop(field, None)

    def remove_unwanted_fields(self, dynamic_params):
        if fields_to_keep := dynamic_params.pop("only_fields", None):
            self.only_keep_fields(fields_to_keep)
        if fields_to_exclude := dynamic_params.pop("exclude_fields", None):
            self.exclude_fields(fields_to_exclude)

    def get_or_create_dynamic_params(self, child):
        if "dynamic_params" not in self.fields[child]._context:
            self.fields[child]._context.update({"dynamic_params": {}})
        return self.fields[child]._context["dynamic_params"]

    @staticmethod
    def split_param(dynamic_param):
        crumbs = dynamic_param.split("__")
        return crumbs[0], "__".join(crumbs[1:]) if len(crumbs) > 1 else None

    def set_dynamic_params_for_children(self, dynamic_params):
        for param, fields in dynamic_params.items():
            child, child_dynamic_param = self.split_param(param)
            if child in set(self.fields.keys()):
                # Named child_params so it does not shadow the dict being iterated.
                child_params = self.get_or_create_dynamic_params(child)
                child_params.update({child_dynamic_param: fields})

    @staticmethod
    def is_param_dynamic(p):
        return p.endswith("only_fields") or p.endswith("exclude_fields")

    def get_dynamic_params_for_root(self, request):
        query_params = request.query_params.items()
        return {k: v for k, v in query_params if self.is_param_dynamic(k)}

    def get_dynamic_params(self):
        """
        When dynamic params get passed down in set_dynamic_params_for_children
        and the child is a ListSerializer (has many=True), the context must be
        fetched from the ListSerializer, which is this serializer's parent.
        """
        if isinstance(self.parent, serializers.ListSerializer):
            return self.parent._context.get("dynamic_params", {})
        return self._context.get("dynamic_params", {})

    def __init__(self, *args, **kwargs):
        request = kwargs.get("context", {}).get("request")
        super().__init__(*args, **kwargs)
        is_root = bool(request)
        if is_root:
            if request.method != "GET":
                return
            dynamic_params = self.get_dynamic_params_for_root(request)
            self._context.update({"dynamic_params": dynamic_params})

    def to_representation(self, *args, **kwargs):
        if dynamic_params := self.get_dynamic_params().copy():
            self.remove_unwanted_fields(dynamic_params)
            self.set_dynamic_params_for_children(dynamic_params)
        return super().to_representation(*args, **kwargs)

    class Meta:
        abstract = True
from myapp.serializers import DynamicModelSerializer
from myapp.models import Company, User


class CompanySerializer(DynamicModelSerializer):
    class Meta:
        model = Company
        fields = ("id", "name")


class UserSerializer(DynamicModelSerializer):
    company = CompanySerializer()

    class Meta:
        model = User
        fields = ("id", "name", "company")
One last thing: for this to work, all the serializers that are involved in the request need to inherit from DynamicModelSerializer.
Thank you for reading, and happy coding!