ce5be947b4
Revert fixture fixes for postgres Use postgres connection string with spaces instead of url Fix label order Make postgres tests in ci less verbose Add sequence update script Skip resets in postgres Remove option to skip resets in postgres Make postgres tests in ci verboseq Update test fixtures database Fix file tests on postgres Add postgres options to sample config Make sure tests init test fixtures before running the actual tests Fix issues with IDs too big to fit in an int Fix duplicate auto incremented IDs Refactor / Fix team tests Refactor team member tests Fix team member create Fix label test Fix getting labels Fix test fixtures for postgresql Fix connection string params Disable ssl mode on postgres integration tests Disable ssl mode on postgres tests Use sprintf to create the connection string for postgresql fixup! Add postgres support Add postgres support Added generate as a make dependency for make build Clarify docs on building Co-authored-by: kolaente <k@knt.li> Co-authored-by: Jan Tojnar <jtojnar@gmail.com> Reviewed-on: https://kolaente.dev/vikunja/api/pulls/135
602 lines
16 KiB
Go
602 lines
16 KiB
Go
package pq
|
|
|
|
import (
|
|
"bytes"
|
|
"database/sql/driver"
|
|
"encoding/binary"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/lib/pq/oid"
|
|
)
|
|
|
|
func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte {
|
|
switch v := x.(type) {
|
|
case []byte:
|
|
return v
|
|
default:
|
|
return encode(parameterStatus, x, oid.T_unknown)
|
|
}
|
|
}
|
|
|
|
func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte {
|
|
switch v := x.(type) {
|
|
case int64:
|
|
return strconv.AppendInt(nil, v, 10)
|
|
case float64:
|
|
return strconv.AppendFloat(nil, v, 'f', -1, 64)
|
|
case []byte:
|
|
if pgtypOid == oid.T_bytea {
|
|
return encodeBytea(parameterStatus.serverVersion, v)
|
|
}
|
|
|
|
return v
|
|
case string:
|
|
if pgtypOid == oid.T_bytea {
|
|
return encodeBytea(parameterStatus.serverVersion, []byte(v))
|
|
}
|
|
|
|
return []byte(v)
|
|
case bool:
|
|
return strconv.AppendBool(nil, v)
|
|
case time.Time:
|
|
return formatTs(v)
|
|
|
|
default:
|
|
errorf("encode: unknown type for %T", v)
|
|
}
|
|
|
|
panic("not reached")
|
|
}
|
|
|
|
func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} {
|
|
switch f {
|
|
case formatBinary:
|
|
return binaryDecode(parameterStatus, s, typ)
|
|
case formatText:
|
|
return textDecode(parameterStatus, s, typ)
|
|
default:
|
|
panic("not reached")
|
|
}
|
|
}
|
|
|
|
func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
|
|
switch typ {
|
|
case oid.T_bytea:
|
|
return s
|
|
case oid.T_int8:
|
|
return int64(binary.BigEndian.Uint64(s))
|
|
case oid.T_int4:
|
|
return int64(int32(binary.BigEndian.Uint32(s)))
|
|
case oid.T_int2:
|
|
return int64(int16(binary.BigEndian.Uint16(s)))
|
|
case oid.T_uuid:
|
|
b, err := decodeUUIDBinary(s)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return b
|
|
|
|
default:
|
|
errorf("don't know how to decode binary parameter of type %d", uint32(typ))
|
|
}
|
|
|
|
panic("not reached")
|
|
}
|
|
|
|
func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
|
|
switch typ {
|
|
case oid.T_char, oid.T_varchar, oid.T_text:
|
|
return string(s)
|
|
case oid.T_bytea:
|
|
b, err := parseBytea(s)
|
|
if err != nil {
|
|
errorf("%s", err)
|
|
}
|
|
return b
|
|
case oid.T_timestamptz:
|
|
return parseTs(parameterStatus.currentLocation, string(s))
|
|
case oid.T_timestamp, oid.T_date:
|
|
return parseTs(nil, string(s))
|
|
case oid.T_time:
|
|
return mustParse("15:04:05", typ, s)
|
|
case oid.T_timetz:
|
|
return mustParse("15:04:05-07", typ, s)
|
|
case oid.T_bool:
|
|
return s[0] == 't'
|
|
case oid.T_int8, oid.T_int4, oid.T_int2:
|
|
i, err := strconv.ParseInt(string(s), 10, 64)
|
|
if err != nil {
|
|
errorf("%s", err)
|
|
}
|
|
return i
|
|
case oid.T_float4, oid.T_float8:
|
|
// We always use 64 bit parsing, regardless of whether the input text is for
|
|
// a float4 or float8, because clients expect float64s for all float datatypes
|
|
// and returning a 32-bit parsed float64 produces lossy results.
|
|
f, err := strconv.ParseFloat(string(s), 64)
|
|
if err != nil {
|
|
errorf("%s", err)
|
|
}
|
|
return f
|
|
}
|
|
|
|
return s
|
|
}
|
|
|
|
// appendEncodedText encodes item in text format as required by COPY
|
|
// and appends to buf
|
|
func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface{}) []byte {
|
|
switch v := x.(type) {
|
|
case int64:
|
|
return strconv.AppendInt(buf, v, 10)
|
|
case float64:
|
|
return strconv.AppendFloat(buf, v, 'f', -1, 64)
|
|
case []byte:
|
|
encodedBytea := encodeBytea(parameterStatus.serverVersion, v)
|
|
return appendEscapedText(buf, string(encodedBytea))
|
|
case string:
|
|
return appendEscapedText(buf, v)
|
|
case bool:
|
|
return strconv.AppendBool(buf, v)
|
|
case time.Time:
|
|
return append(buf, formatTs(v)...)
|
|
case nil:
|
|
return append(buf, "\\N"...)
|
|
default:
|
|
errorf("encode: unknown type for %T", v)
|
|
}
|
|
|
|
panic("not reached")
|
|
}
|
|
|
|
// appendEscapedText appends text to buf and returns the extended slice. On
// the way it escapes backslash, newline, carriage return and tab as the
// two-byte sequences \\, \n, \r and \t. Text containing none of these is
// appended with a single copy.
func appendEscapedText(buf []byte, text string) []byte {
	first := strings.IndexAny(text, "\\\n\r\t")
	if first < 0 {
		return append(buf, text...)
	}

	// Bulk-copy the clean prefix, then escape byte by byte from the first
	// special character onward.
	out := append(buf, text[:first]...)
	for j := first; j < len(text); j++ {
		switch ch := text[j]; ch {
		case '\\':
			out = append(out, '\\', '\\')
		case '\n':
			out = append(out, '\\', 'n')
		case '\r':
			out = append(out, '\\', 'r')
		case '\t':
			out = append(out, '\\', 't')
		default:
			out = append(out, ch)
		}
	}
	return out
}
|
|
|
|
func mustParse(f string, typ oid.Oid, s []byte) time.Time {
|
|
str := string(s)
|
|
|
|
// check for a 30-minute-offset timezone
|
|
if (typ == oid.T_timestamptz || typ == oid.T_timetz) &&
|
|
str[len(str)-3] == ':' {
|
|
f += ":00"
|
|
}
|
|
t, err := time.Parse(f, str)
|
|
if err != nil {
|
|
errorf("decode: %s", err)
|
|
}
|
|
return t
|
|
}
|
|
|
|
// errInvalidTimestamp is recorded when a field the parser expects lies
// outside the input string (input too short or bounds out of range).
var errInvalidTimestamp = errors.New("invalid timestamp")

// timestampParser keeps the first error encountered while parsing a
// timestamp. Once err is set, every further expect/mustAtoi call is a no-op.
// A caller can therefore run the whole sequence of field reads and check err
// once at the end.
type timestampParser struct {
	err error
}

// expect records an error unless str[pos] is char. A pos outside str records
// errInvalidTimestamp.
func (p *timestampParser) expect(str string, char byte, pos int) {
	if p.err != nil {
		return
	}
	if pos < 0 || pos >= len(str) {
		p.err = errInvalidTimestamp
		return
	}
	if c := str[pos]; c != char {
		// %q renders the bytes as quoted characters ('-') rather than their
		// decimal values (45).
		p.err = fmt.Errorf("expected %q at position %v; got %q", char, pos, c)
	}
}

// mustAtoi parses str[begin:end] as a decimal integer.
//
// Invalid bounds record errInvalidTimestamp. Non-numeric text records an
// error naming the offending field. The result is 0 whenever an error is, or
// already was, recorded.
func (p *timestampParser) mustAtoi(str string, begin int, end int) int {
	if p.err != nil {
		return 0
	}
	if begin < 0 || end < 0 || begin > end || end > len(str) {
		p.err = errInvalidTimestamp
		return 0
	}
	field := str[begin:end]
	n, err := strconv.Atoi(field)
	if err != nil {
		p.err = fmt.Errorf("expected number at position %d; got %q in %q", begin, field, str)
		return 0
	}
	return n
}
|
|
|
|
// locationCache memoizes fixed-offset time zones, keyed by their offset in
// seconds east of UTC, so that repeated offsets share one *time.Location.
type locationCache struct {
	cache map[int]*time.Location
	lock  sync.Mutex
}

// globalLocationCache is the single cache shared by every connection.
// Benchmarks showed a per-connection cache without the mutex to be about 5%
// faster. That gain would cost a little memory per connection and add
// noticeable code complexity, so one shared cache is used instead.
var globalLocationCache = newLocationCache()

// newLocationCache returns an empty, ready-to-use locationCache.
func newLocationCache() *locationCache {
	c := new(locationCache)
	c.cache = make(map[int]*time.Location)
	return c
}

// getLocation returns the cached zone for offset. On first use it creates an
// unnamed time.FixedZone and caches it. Access to the map is serialized by
// the cache's mutex.
func (c *locationCache) getLocation(offset int) *time.Location {
	c.lock.Lock()
	defer c.lock.Unlock()

	if loc, found := c.cache[offset]; found {
		return loc
	}
	loc := time.FixedZone("", offset)
	c.cache[offset] = loc
	return loc
}
|
|
|
|
// Package-wide state for the optional infinity-timestamp mapping configured
// by EnableInfinityTs.
var (
	infinityTsEnabled  bool
	infinityTsNegative time.Time
	infinityTsPositive time.Time
)

// Panic messages raised by EnableInfinityTs.
const (
	infinityTsEnabledAlready        = "pq: infinity timestamp enabled already"
	infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive"
)

// EnableInfinityTs controls how Postgres' "-infinity" and "infinity"
// "timestamp" values are handled.
//
// Without a call to EnableInfinityTs, "-infinity" and "infinity" decode to
// []byte("-infinity") and []byte("infinity"). Scanning them into a time.Time
// may then fail with "sql: Scan error on column index 0: unsupported driver
// -> Scan pair: []uint8 -> *time.Time".
//
// After EnableInfinityTs has been called, all connections created by this
// driver decode "-infinity" and "infinity" to the given minimum and maximum
// times. This applies to "timestamp", "timestamp with time zone" and "date".
// When encoding a time.Time:
//   - any value at or before the minimum is sent as "-infinity";
//   - any value at or after the maximum is sent as "infinity".
//
// EnableInfinityTs panics if negative is not strictly before positive, and if
// it is called more than once. Calling it after a connection has been
// established results in undefined behavior.
func EnableInfinityTs(negative time.Time, positive time.Time) {
	switch {
	case infinityTsEnabled:
		panic(infinityTsEnabledAlready)
	case !negative.Before(positive):
		panic(infinityTsNegativeMustBeSmaller)
	}
	infinityTsEnabled = true
	infinityTsNegative, infinityTsPositive = negative, positive
}

// disableInfinityTs switches the infinity mapping back off so that tests can
// toggle infinityTsEnabled. The stored bounds are left as they are.
func disableInfinityTs() {
	infinityTsEnabled = false
}
|
|
|
|
// This is a time function specific to the Postgres default DateStyle
|
|
// setting ("ISO, MDY"), the only one we currently support. This
|
|
// accounts for the discrepancies between the parsing available with
|
|
// time.Parse and the Postgres date formatting quirks.
|
|
func parseTs(currentLocation *time.Location, str string) interface{} {
|
|
switch str {
|
|
case "-infinity":
|
|
if infinityTsEnabled {
|
|
return infinityTsNegative
|
|
}
|
|
return []byte(str)
|
|
case "infinity":
|
|
if infinityTsEnabled {
|
|
return infinityTsPositive
|
|
}
|
|
return []byte(str)
|
|
}
|
|
t, err := ParseTimestamp(currentLocation, str)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return t
|
|
}
|
|
|
|
// ParseTimestamp parses Postgres' text format. It returns a time.Time in
|
|
// currentLocation iff that time's offset agrees with the offset sent from the
|
|
// Postgres server. Otherwise, ParseTimestamp returns a time.Time with the
|
|
// fixed offset offset provided by the Postgres server.
|
|
func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) {
|
|
p := timestampParser{}
|
|
|
|
monSep := strings.IndexRune(str, '-')
|
|
// this is Gregorian year, not ISO Year
|
|
// In Gregorian system, the year 1 BC is followed by AD 1
|
|
year := p.mustAtoi(str, 0, monSep)
|
|
daySep := monSep + 3
|
|
month := p.mustAtoi(str, monSep+1, daySep)
|
|
p.expect(str, '-', daySep)
|
|
timeSep := daySep + 3
|
|
day := p.mustAtoi(str, daySep+1, timeSep)
|
|
|
|
minLen := monSep + len("01-01") + 1
|
|
|
|
isBC := strings.HasSuffix(str, " BC")
|
|
if isBC {
|
|
minLen += 3
|
|
}
|
|
|
|
var hour, minute, second int
|
|
if len(str) > minLen {
|
|
p.expect(str, ' ', timeSep)
|
|
minSep := timeSep + 3
|
|
p.expect(str, ':', minSep)
|
|
hour = p.mustAtoi(str, timeSep+1, minSep)
|
|
secSep := minSep + 3
|
|
p.expect(str, ':', secSep)
|
|
minute = p.mustAtoi(str, minSep+1, secSep)
|
|
secEnd := secSep + 3
|
|
second = p.mustAtoi(str, secSep+1, secEnd)
|
|
}
|
|
remainderIdx := monSep + len("01-01 00:00:00") + 1
|
|
// Three optional (but ordered) sections follow: the
|
|
// fractional seconds, the time zone offset, and the BC
|
|
// designation. We set them up here and adjust the other
|
|
// offsets if the preceding sections exist.
|
|
|
|
nanoSec := 0
|
|
tzOff := 0
|
|
|
|
if remainderIdx < len(str) && str[remainderIdx] == '.' {
|
|
fracStart := remainderIdx + 1
|
|
fracOff := strings.IndexAny(str[fracStart:], "-+ ")
|
|
if fracOff < 0 {
|
|
fracOff = len(str) - fracStart
|
|
}
|
|
fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff)
|
|
nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff))))
|
|
|
|
remainderIdx += fracOff + 1
|
|
}
|
|
if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') {
|
|
// time zone separator is always '-' or '+' (UTC is +00)
|
|
var tzSign int
|
|
switch c := str[tzStart]; c {
|
|
case '-':
|
|
tzSign = -1
|
|
case '+':
|
|
tzSign = +1
|
|
default:
|
|
return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c)
|
|
}
|
|
tzHours := p.mustAtoi(str, tzStart+1, tzStart+3)
|
|
remainderIdx += 3
|
|
var tzMin, tzSec int
|
|
if remainderIdx < len(str) && str[remainderIdx] == ':' {
|
|
tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
|
|
remainderIdx += 3
|
|
}
|
|
if remainderIdx < len(str) && str[remainderIdx] == ':' {
|
|
tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
|
|
remainderIdx += 3
|
|
}
|
|
tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec)
|
|
}
|
|
var isoYear int
|
|
|
|
if isBC {
|
|
isoYear = 1 - year
|
|
remainderIdx += 3
|
|
} else {
|
|
isoYear = year
|
|
}
|
|
if remainderIdx < len(str) {
|
|
return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:])
|
|
}
|
|
t := time.Date(isoYear, time.Month(month), day,
|
|
hour, minute, second, nanoSec,
|
|
globalLocationCache.getLocation(tzOff))
|
|
|
|
if currentLocation != nil {
|
|
// Set the location of the returned Time based on the session's
|
|
// TimeZone value, but only if the local time zone database agrees with
|
|
// the remote database on the offset.
|
|
lt := t.In(currentLocation)
|
|
_, newOff := lt.Zone()
|
|
if newOff == tzOff {
|
|
t = lt
|
|
}
|
|
}
|
|
|
|
return t, p.err
|
|
}
|
|
|
|
// formatTs formats t into a format postgres understands.
|
|
func formatTs(t time.Time) []byte {
|
|
if infinityTsEnabled {
|
|
// t <= -infinity : ! (t > -infinity)
|
|
if !t.After(infinityTsNegative) {
|
|
return []byte("-infinity")
|
|
}
|
|
// t >= infinity : ! (!t < infinity)
|
|
if !t.Before(infinityTsPositive) {
|
|
return []byte("infinity")
|
|
}
|
|
}
|
|
return FormatTimestamp(t)
|
|
}
|
|
|
|
// FormatTimestamp formats t in Postgres' text format for timestamps, e.g.
// "2001-02-03 04:05:06.789+07:30":
//   - Fractional seconds are printed only when non-zero.
//   - A zero UTC offset is written as "Z".
//   - An offset with a seconds component gets a trailing ":SS".
//
// Years before AD 1 are written with a " BC" suffix instead of Go's minus
// sign. ISO year 0 is 1 BC, -1 is 2 BC, and so on. Only the year number is
// converted; month, day and time are taken from t unchanged. This preserves
// dates that exist only in leap years (e.g. Feb 29 of 1 BC), which shifting
// t by whole years with AddDate would normalize to March 1.
func FormatTimestamp(t time.Time) []byte {
	year := t.Year()
	bc := year <= 0
	if bc {
		year = 1 - year
	}

	b := make([]byte, 0, len("0000-00-00 00:00:00.000000000+00:00:00 BC"))
	// Zero-pad the year to at least four digits, as Go's "2006" does.
	for width := 1000; width > 1 && year < width; width /= 10 {
		b = append(b, '0')
	}
	b = strconv.AppendInt(b, int64(year), 10)
	b = t.AppendFormat(b, "-01-02 15:04:05.999999999Z07:00")

	_, offset := t.Zone()
	offset %= 60
	if offset != 0 {
		// The layout prints only the hours and minutes of the offset (with
		// its sign), so append the leftover seconds without a sign.
		if offset < 0 {
			offset = -offset
		}
		b = append(b, ':')
		if offset < 10 {
			b = append(b, '0')
		}
		b = strconv.AppendInt(b, int64(offset), 10)
	}

	if bc {
		b = append(b, " BC"...)
	}
	return b
}
|
|
|
|
// parseBytea decodes a bytea value received from the server. Two input
// formats are supported:
//   - "hex": the prefix "\x" followed by hex digits.
//   - the legacy "escape": raw bytes, "\\" for a backslash, and "\NNN" for a
//     byte given as three octal digits.
//
// Malformed input yields an error and a nil result.
func parseBytea(s []byte) (result []byte, err error) {
	if len(s) >= 2 && s[0] == '\\' && s[1] == 'x' {
		// bytea_output = hex
		hexDigits := s[2:]
		result = make([]byte, hex.DecodedLen(len(hexDigits)))
		if _, err := hex.Decode(result, hexDigits); err != nil {
			return nil, err
		}
		return result, nil
	}

	// bytea_output = escape
	for len(s) > 0 {
		if s[0] != '\\' {
			// A run of unescaped bytes: copy everything up to the next
			// backslash in one go.
			i := bytes.IndexByte(s, '\\')
			if i < 0 {
				return append(result, s...), nil
			}
			result = append(result, s[:i]...)
			s = s[i:]
			continue
		}

		// An escaped backslash.
		if len(s) >= 2 && s[1] == '\\' {
			result = append(result, '\\')
			s = s[2:]
			continue
		}

		// A backslash followed by three octal digits. ParseUint with bitSize
		// 8 rejects sign prefixes ("\+12", "\-12") as well as values above
		// \377, so only genuine byte escapes are accepted.
		if len(s) < 4 {
			return nil, fmt.Errorf("invalid bytea sequence %v", s)
		}
		r, err := strconv.ParseUint(string(s[1:4]), 8, 8)
		if err != nil {
			return nil, fmt.Errorf("could not parse bytea value: %s", err.Error())
		}
		result = append(result, byte(r))
		s = s[4:]
	}
	return result, nil
}
|
|
|
|
// encodeBytea encodes v as bytea text input for a server reporting
// serverVersion.
//
// When serverVersion >= 90000 the hex format is used: "\x" followed by two
// lowercase hex digits per byte. Otherwise the escape format is used:
//   - a backslash is doubled;
//   - every byte outside 0x20-0x7e becomes a three-digit octal escape "\NNN";
//   - every other byte is copied as-is.
func encodeBytea(serverVersion int, v []byte) (result []byte) {
	if serverVersion < 90000 {
		// The server may not support hex, so fall back to escape.
		for _, c := range v {
			switch {
			case c == '\\':
				result = append(result, '\\', '\\')
			case c < 0x20 || c > 0x7e:
				result = append(result, '\\', '0'+(c>>6), '0'+((c>>3)&7), '0'+(c&7))
			default:
				result = append(result, c)
			}
		}
		return result
	}

	result = make([]byte, 2+hex.EncodedLen(len(v)))
	copy(result, `\x`)
	hex.Encode(result[2:], v)
	return result
}
|
|
|
|
// NullTime represents a time.Time that may be null. NullTime implements the
// sql.Scanner interface, so it can be used as a scan destination, similar to
// sql.NullString.
type NullTime struct {
	Time  time.Time
	Valid bool // Valid is true if Time is not NULL
}

// Scan implements the sql.Scanner interface.
//
// A nil value yields a NULL NullTime; a time.Time value yields a valid one.
// Any other value is rejected with an error instead of being silently
// reported as NULL. An example is the []byte("infinity") produced for
// infinite timestamps when EnableInfinityTs has not been called. The
// sql.Scanner contract asks for an error when a value cannot be stored
// without loss of information.
func (nt *NullTime) Scan(value interface{}) error {
	switch v := value.(type) {
	case nil:
		nt.Time, nt.Valid = time.Time{}, false
		return nil
	case time.Time:
		nt.Time, nt.Valid = v, true
		return nil
	default:
		nt.Time, nt.Valid = time.Time{}, false
		return fmt.Errorf("pq: cannot scan type %T into NullTime", value)
	}
}

// Value implements the driver.Valuer interface. It returns nil (SQL NULL)
// when Valid is false and the wrapped time.Time otherwise.
func (nt NullTime) Value() (driver.Value, error) {
	if !nt.Valid {
		return nil, nil
	}
	return nt.Time, nil
}
|