Arthur de Jong

Open Source / Free Software developer

summaryrefslogtreecommitdiffstats
path: root/django/db/models/fields/related_lookups.py
blob: d3e4f8fa8a3434ff5d16ed0cdd3cfe2b9d5b1308 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from django.db.models.lookups import (
    Exact, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual,
)


class MultiColSource(object):
    contains_aggregate = False

    def __init__(self, alias, targets, sources, field):
        self.targets, self.sources, self.field, self.alias = targets, sources, field, alias
        self.output_field = self.field

    def __repr__(self):
        return "{}({}, {})".format(
            self.__class__.__name__, self.alias, self.field)

    def relabeled_clone(self, relabels):
        return self.__class__(relabels.get(self.alias, self.alias),
                              self.targets, self.sources, self.field)


def get_normalized_value(value, lhs):
    from django.db.models import Model
    if isinstance(value, Model):
        value_list = []
        # A case like Restaurant.objects.filter(place=restaurant_instance),
        # where place is a OneToOneField and the primary key of Restaurant.
        if getattr(lhs.output_field, 'primary_key', False):
            return (value.pk,)
        sources = lhs.output_field.get_path_info()[-1].target_fields
        for source in sources:
            while not isinstance(value, source.model) and source.remote_field:
                source = source.remote_field.model._meta.get_field(source.remote_field.field_name)
            value_list.append(getattr(value, source.attname))
        return tuple(value_list)
    if not isinstance(value, tuple):
        return (value,)
    return value


class RelatedIn(In):
    def get_prep_lookup(self):
        if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value():
            # If we get here, we are dealing with single-column relations.
            self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
            # We need to run the related field's get_prep_lookup(). Consider case
            # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
            # doesn't have validation for non-integers, so we must run validation
            # using the target field.
            if hasattr(self.lhs.output_field, 'get_path_info'):
                # Run the target field's get_prep_lookup. We can safely assume there is
                # only one as we don't get to the direct value branch otherwise.
                self.rhs = self.lhs.output_field.get_path_info()[-1].target_fields[-1].get_prep_lookup(
                    self.lookup_name, self.rhs)
        return super(RelatedIn, self).get_prep_lookup()

    def as_sql(self, compiler, connection):
        if isinstance(self.lhs, MultiColSource):
            # For multicolumn lookups we need to build a multicolumn where clause.
            # This clause is either a SubqueryConstraint (for values that need to be compiled to
            # SQL) or a OR-combined list of (col1 = val1 AND col2 = val2 AND ...) clauses.
            from django.db.models.sql.where import WhereNode, SubqueryConstraint, AND, OR

            root_constraint = WhereNode(connector=OR)
            if self.rhs_is_direct_value():
                values = [get_normalized_value(value, self.lhs) for value in self.rhs]
                for value in values:
                    value_constraint = WhereNode()
                    for source, target, val in zip(self.lhs.sources, self.lhs.targets, value):
                        lookup_class = target.get_lookup('exact')
                        lookup = lookup_class(target.get_col(self.lhs.alias, source), val)
                        value_constraint.add(lookup, AND)
                    root_constraint.add(value_constraint, OR)
            else:
                root_constraint.add(
                    SubqueryConstraint(
                        self.lhs.alias, [target.column for target in self.lhs.targets],
                        [source.name for source in self.lhs.sources], self.rhs),
                    AND)
            return root_constraint.as_sql(compiler, connection)
        else:
            return super(RelatedIn, self).as_sql(compiler, connection)


class RelatedLookupMixin(object):
    def get_prep_lookup(self):
        if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value():
            # If we get here, we are dealing with single-column relations.
            self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
            # We need to run the related field's get_prep_lookup(). Consider case
            # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
            # doesn't have validation for non-integers, so we must run validation
            # using the target field.
            if hasattr(self.lhs.output_field, 'get_path_info'):
                # Get the target field. We can safely assume there is only one
                # as we don't get to the direct value branch otherwise.
                self.rhs = self.lhs.output_field.get_path_info()[-1].target_fields[-1].get_prep_lookup(
                    self.lookup_name, self.rhs)

        return super(RelatedLookupMixin, self).get_prep_lookup()

    def as_sql(self, compiler, connection):
        if isinstance(self.lhs, MultiColSource):
            assert self.rhs_is_direct_value()
            self.rhs = get_normalized_value(self.rhs, self.lhs)
            from django.db.models.sql.where import WhereNode, AND
            root_constraint = WhereNode()
            for target, source, val in zip(self.lhs.targets, self.lhs.sources, self.rhs):
                lookup_class = target.get_lookup(self.lookup_name)
                root_constraint.add(
                    lookup_class(target.get_col(self.lhs.alias, source), val), AND)
            return root_constraint.as_sql(compiler, connection)
        return super(RelatedLookupMixin, self).as_sql(compiler, connection)


class RelatedExact(RelatedLookupMixin, Exact):
    pass


class RelatedLessThan(RelatedLookupMixin, LessThan):
    pass


class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
    pass


class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
    pass


class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
    pass