Arthur de Jong

Open Source / Free Software developer

summaryrefslogtreecommitdiffstats
path: root/django/db/models/functions.py
blob: 4cbcc9c899975d5e13ee3064b804260e0202a609 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
"""
Classes that represent database functions.
"""
from django.db.models import (
    DateTimeField, Func, IntegerField, Transform, Value,
)


class Coalesce(Func):
    """
    Chooses, from left to right, the first non-null expression and returns it.

    Raises ValueError when fewer than two expressions are given.
    """
    function = 'COALESCE'

    def __init__(self, *expressions, **extra):
        if len(expressions) < 2:
            raise ValueError('Coalesce must take at least two expressions')
        super(Coalesce, self).__init__(*expressions, **extra)

    def as_oracle(self, compiler, connection):
        """
        Compile for Oracle, converting every argument with TO_NCLOB when the
        output field is a TextField.
        """
        # we can't mix TextField (NCLOB) and CharField (NVARCHAR), so convert
        # all fields to NCLOB when we expect NCLOB
        if self.output_field.get_internal_type() == 'TextField':
            class ToNCLOB(Func):
                function = 'TO_NCLOB'

            # Wrap the arguments on a copy, never on self: replacing self's
            # source expressions would stack another TO_NCLOB on every
            # compilation and would leak the wrapping into other backends.
            clone = self.copy()
            clone.set_source_expressions([
                ToNCLOB(expression) for expression in self.get_source_expressions()
            ])
            return super(Coalesce, clone).as_sql(compiler, connection)
        return super(Coalesce, self).as_sql(compiler, connection)


class ConcatPair(Func):
    """
    A helper class that concatenates two arguments together. This is used
    by `Concat` because not all backend databases support more than two
    arguments.
    """
    function = 'CONCAT'

    def __init__(self, left, right, **extra):
        super(ConcatPair, self).__init__(left, right, **extra)

    def as_sqlite(self, compiler, connection):
        """Compile for SQLite as `left || right` with both sides coalesced."""
        coalesced = self.coalesce()
        coalesced.arg_joiner = ' || '
        coalesced.template = '%(expressions)s'
        return super(ConcatPair, coalesced).as_sql(compiler, connection)

    def as_mysql(self, compiler, connection):
        """Compile for MySQL as `CONCAT_WS('', left, right)`."""
        # Use CONCAT_WS with an empty separator so that NULLs are ignored.
        # The overrides are applied to a copy / passed per call instead of
        # being assigned on self, so that this instance still compiles to
        # plain CONCAT when it is later used with a different backend.
        clone = self.copy()
        clone.template = "%(function)s('', %(expressions)s)"
        return super(ConcatPair, clone).as_sql(
            compiler, connection, function='CONCAT_WS')

    def coalesce(self):
        """
        Return a copy of this expression in which each argument is wrapped
        in `Coalesce(argument, Value(''))`.
        """
        # null on either side results in null for expression, wrap with coalesce
        c = self.copy()
        expressions = [
            Coalesce(expression, Value('')) for expression in c.get_source_expressions()
        ]
        c.set_source_expressions(expressions)
        return c


class Concat(Func):
    """
    Joins two or more text expressions into a single value.

    On backends where one null argument turns the whole concatenation into
    null, every argument gets wrapped in a coalesce function so that the
    result is always non-null.
    """
    function = None
    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        if len(expressions) < 2:
            raise ValueError('Concat must take at least two expressions')
        super(Concat, self).__init__(self._paired(expressions), **extra)

    def _paired(self, expressions):
        # Build a right-nested chain of two-argument concatenations:
        #   [a, b, c, d] -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))
        # Start with the innermost pair and wrap outwards.
        nested = ConcatPair(expressions[-2], expressions[-1])
        for expression in reversed(expressions[:-2]):
            nested = ConcatPair(expression, nested)
        return nested


class Greatest(Func):
    """
    Returns the largest of the given expressions.

    How a null argument is treated depends on the database:
    - Postgres returns the largest expression that is not null.
    - MySQL, Oracle, and SQLite return null if any expression is null.
    """
    function = 'GREATEST'

    def __init__(self, *expressions, **extra):
        # A lone expression has nothing to be compared with.
        count = len(expressions)
        if count < 2:
            raise ValueError('Greatest must take at least two expressions')
        super(Greatest, self).__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection):
        """Compile for SQLite, which uses MAX in place of GREATEST."""
        base = super(Greatest, self)
        return base.as_sql(compiler, connection, function='MAX')


class Least(Func):
    """
    Returns the smallest of the given expressions.

    How a null argument is treated depends on the database:
    - Postgres returns the smallest expression that is not null.
    - MySQL, Oracle, and SQLite return null if any expression is null.
    """
    function = 'LEAST'

    def __init__(self, *expressions, **extra):
        # A lone expression has nothing to be compared with.
        count = len(expressions)
        if count < 2:
            raise ValueError('Least must take at least two expressions')
        super(Least, self).__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection):
        """Compile for SQLite, which uses MIN in place of LEAST."""
        base = super(Least, self)
        return base.as_sql(compiler, connection, function='MIN')


class Length(Transform):
    """
    Returns the number of characters in the expression.

    The output field defaults to an IntegerField unless `output_field` is
    passed explicitly.
    """
    function = 'LENGTH'
    lookup_name = 'length'

    def __init__(self, expression, **extra):
        output_field = extra.pop('output_field', IntegerField())
        super(Length, self).__init__(expression, output_field=output_field, **extra)

    def as_mysql(self, compiler, connection):
        """Compile for MySQL using CHAR_LENGTH instead of LENGTH."""
        # Pass the function name for this call only; assigning it to
        # self.function would make the instance emit CHAR_LENGTH on every
        # other backend it is compiled for afterwards.
        # NOTE(review): presumably needed because MySQL's LENGTH counts
        # bytes rather than characters -- confirm against the MySQL docs.
        return super(Length, self).as_sql(
            compiler, connection, function='CHAR_LENGTH')


class Lower(Transform):
    """Applies the SQL LOWER function to the expression."""
    lookup_name = 'lower'
    function = 'LOWER'


class Now(Func):
    """
    Renders as the database's CURRENT_TIMESTAMP.

    The output field defaults to a DateTimeField unless `output_field` is
    given.
    """
    template = 'CURRENT_TIMESTAMP'

    def __init__(self, output_field=None, **extra):
        if output_field is None:
            output_field = DateTimeField()
        super(Now, self).__init__(output_field=output_field, **extra)

    def as_postgresql(self, compiler, connection):
        """Compile for Postgres using STATEMENT_TIMESTAMP()."""
        # Postgres' CURRENT_TIMESTAMP means "the time at the start of the
        # transaction". We use STATEMENT_TIMESTAMP to be cross-compatible with
        # other databases.
        # The template is swapped on a copy so that this instance keeps
        # rendering CURRENT_TIMESTAMP when compiled for another backend.
        clone = self.copy()
        clone.template = 'STATEMENT_TIMESTAMP()'
        return clone.as_sql(compiler, connection)


class Substr(Func):
    """Returns a substring of the expression, compiled as SUBSTRING."""
    function = 'SUBSTRING'

    def __init__(self, expression, pos, length=None, **extra):
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer
        length: an optional number of characters to return

        Raises ValueError if `pos` is a plain value smaller than 1.
        """
        # Plain Python values are validated and wrapped in Value(); anything
        # that looks like an expression is passed through untouched.
        if not hasattr(pos, 'resolve_expression'):
            if pos < 1:
                raise ValueError("'pos' must be greater than 0")
            pos = Value(pos)
        expressions = [expression, pos]
        if length is not None:
            if not hasattr(length, 'resolve_expression'):
                length = Value(length)
            expressions.append(length)
        super(Substr, self).__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection):
        """Compile for SQLite using SUBSTR."""
        # The function name is passed per call instead of being assigned to
        # self.function, so the instance still emits SUBSTRING elsewhere.
        return super(Substr, self).as_sql(compiler, connection, function='SUBSTR')

    def as_oracle(self, compiler, connection):
        """Compile for Oracle using SUBSTR."""
        # Per-call override for the same reason as in as_sqlite().
        return super(Substr, self).as_sql(compiler, connection, function='SUBSTR')


class Upper(Transform):
    """Applies the SQL UPPER function to the expression."""
    lookup_name = 'upper'
    function = 'UPPER'