ce5be947b4
Revert fixture fixes for postgres Use postgres connection string with spaces instead of url Fix label order Make postgres tests in ci less verbose Add sequence update script Skip resets in postgres Remove option to skip resets in postgres Make postgres tests in ci verboseq Update test fixtures database Fix file tests on postgres Add postgres options to sample config Make sure tests init test fixtures before running the actual tests Fix issues with IDs too big to fit in an int Fix duplicate auto incremented IDs Refactor / Fix team tests Refactor team member tests Fix team member create Fix label test Fix getting labels Fix test fixtures for postgresql Fix connection string params Disable ssl mode on postgres integration tests Disable ssl mode on postgres tests Use sprintf to create the connection string for postgresql fixup! Add postgres support Add postgres support Added generate as a make dependency for make build Clarify docs on building Co-authored-by: kolaente <k@knt.li> Co-authored-by: Jan Tojnar <jtojnar@gmail.com> Reviewed-on: https://kolaente.dev/vikunja/api/pulls/135
296 lines
6.6 KiB
Go
296 lines
6.6 KiB
Go
package testfixtures
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"strings"
|
|
)
|
|
|
|
type postgreSQL struct {
|
|
baseHelper
|
|
|
|
useAlterConstraint bool
|
|
skipResetSequences bool
|
|
resetSequencesTo int64
|
|
|
|
tables []string
|
|
sequences []string
|
|
nonDeferrableConstraints []pgConstraint
|
|
tablesChecksum map[string]string
|
|
}
|
|
|
|
type pgConstraint struct {
|
|
tableName string
|
|
constraintName string
|
|
}
|
|
|
|
func (h *postgreSQL) init(db *sql.DB) error {
|
|
var err error
|
|
|
|
h.tables, err = h.tableNames(db)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
h.sequences, err = h.getSequences(db)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
h.nonDeferrableConstraints, err = h.getNonDeferrableConstraints(db)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (*postgreSQL) paramType() int {
|
|
return paramTypeDollar
|
|
}
|
|
|
|
func (*postgreSQL) databaseName(q queryable) (string, error) {
|
|
var dbName string
|
|
err := q.QueryRow("SELECT current_database()").Scan(&dbName)
|
|
return dbName, err
|
|
}
|
|
|
|
func (h *postgreSQL) tableNames(q queryable) ([]string, error) {
|
|
var tables []string
|
|
|
|
sql := `
|
|
SELECT pg_namespace.nspname || '.' || pg_class.relname
|
|
FROM pg_class
|
|
INNER JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
|
|
WHERE pg_class.relkind = 'r'
|
|
AND pg_namespace.nspname NOT IN ('pg_catalog', 'information_schema')
|
|
AND pg_namespace.nspname NOT LIKE 'pg_toast%'
|
|
AND pg_namespace.nspname NOT LIKE '\_timescaledb%';
|
|
`
|
|
rows, err := q.Query(sql)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var table string
|
|
if err = rows.Scan(&table); err != nil {
|
|
return nil, err
|
|
}
|
|
tables = append(tables, table)
|
|
}
|
|
if err = rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return tables, nil
|
|
}
|
|
|
|
func (h *postgreSQL) getSequences(q queryable) ([]string, error) {
|
|
const sql = `
|
|
SELECT pg_namespace.nspname || '.' || pg_class.relname AS sequence_name
|
|
FROM pg_class
|
|
INNER JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
|
|
WHERE pg_class.relkind = 'S'
|
|
AND pg_namespace.nspname NOT LIKE '\_timescaledb%'
|
|
`
|
|
|
|
rows, err := q.Query(sql)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var sequences []string
|
|
for rows.Next() {
|
|
var sequence string
|
|
if err = rows.Scan(&sequence); err != nil {
|
|
return nil, err
|
|
}
|
|
sequences = append(sequences, sequence)
|
|
}
|
|
if err = rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return sequences, nil
|
|
}
|
|
|
|
func (*postgreSQL) getNonDeferrableConstraints(q queryable) ([]pgConstraint, error) {
|
|
var constraints []pgConstraint
|
|
|
|
sql := `
|
|
SELECT table_schema || '.' || table_name, constraint_name
|
|
FROM information_schema.table_constraints
|
|
WHERE constraint_type = 'FOREIGN KEY'
|
|
AND is_deferrable = 'NO'
|
|
AND table_schema NOT LIKE '\_timescaledb%'
|
|
`
|
|
rows, err := q.Query(sql)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var constraint pgConstraint
|
|
if err = rows.Scan(&constraint.tableName, &constraint.constraintName); err != nil {
|
|
return nil, err
|
|
}
|
|
constraints = append(constraints, constraint)
|
|
}
|
|
if err = rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return constraints, nil
|
|
}
|
|
|
|
func (h *postgreSQL) disableTriggers(db *sql.DB, loadFn loadFunction) (err error) {
|
|
defer func() {
|
|
// re-enable triggers after load
|
|
var sql string
|
|
for _, table := range h.tables {
|
|
sql += fmt.Sprintf("ALTER TABLE %s ENABLE TRIGGER ALL;", h.quoteKeyword(table))
|
|
}
|
|
if _, err2 := db.Exec(sql); err2 != nil && err == nil {
|
|
err = err2
|
|
}
|
|
}()
|
|
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var sql string
|
|
for _, table := range h.tables {
|
|
sql += fmt.Sprintf("ALTER TABLE %s DISABLE TRIGGER ALL;", h.quoteKeyword(table))
|
|
}
|
|
if _, err = tx.Exec(sql); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err = loadFn(tx); err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (h *postgreSQL) makeConstraintsDeferrable(db *sql.DB, loadFn loadFunction) (err error) {
|
|
defer func() {
|
|
// ensure constraint being not deferrable again after load
|
|
var sql string
|
|
for _, constraint := range h.nonDeferrableConstraints {
|
|
sql += fmt.Sprintf("ALTER TABLE %s ALTER CONSTRAINT %s NOT DEFERRABLE;", h.quoteKeyword(constraint.tableName), h.quoteKeyword(constraint.constraintName))
|
|
}
|
|
if _, err2 := db.Exec(sql); err2 != nil && err == nil {
|
|
err = err2
|
|
}
|
|
}()
|
|
|
|
var sql string
|
|
for _, constraint := range h.nonDeferrableConstraints {
|
|
sql += fmt.Sprintf("ALTER TABLE %s ALTER CONSTRAINT %s DEFERRABLE;", h.quoteKeyword(constraint.tableName), h.quoteKeyword(constraint.constraintName))
|
|
}
|
|
if _, err := db.Exec(sql); err != nil {
|
|
return err
|
|
}
|
|
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
if _, err = tx.Exec("SET CONSTRAINTS ALL DEFERRED"); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err = loadFn(tx); err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (h *postgreSQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
|
|
// ensure sequences being reset after load
|
|
if !h.skipResetSequences {
|
|
defer func() {
|
|
if err2 := h.resetSequences(db); err2 != nil && err == nil {
|
|
err = err2
|
|
}
|
|
}()
|
|
}
|
|
|
|
if h.useAlterConstraint {
|
|
return h.makeConstraintsDeferrable(db, loadFn)
|
|
}
|
|
return h.disableTriggers(db, loadFn)
|
|
}
|
|
|
|
func (h *postgreSQL) resetSequences(db *sql.DB) error {
|
|
resetSequencesTo := h.resetSequencesTo
|
|
if resetSequencesTo == 0 {
|
|
resetSequencesTo = 10000
|
|
}
|
|
|
|
for _, sequence := range h.sequences {
|
|
_, err := db.Exec(fmt.Sprintf("SELECT SETVAL('%s', %d)", sequence, resetSequencesTo))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *postgreSQL) isTableModified(q queryable, tableName string) (bool, error) {
|
|
checksum, err := h.getChecksum(q, tableName)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
oldChecksum := h.tablesChecksum[tableName]
|
|
|
|
return oldChecksum == "" || checksum != oldChecksum, nil
|
|
}
|
|
|
|
func (h *postgreSQL) afterLoad(q queryable) error {
|
|
if h.tablesChecksum != nil {
|
|
return nil
|
|
}
|
|
|
|
h.tablesChecksum = make(map[string]string, len(h.tables))
|
|
for _, t := range h.tables {
|
|
checksum, err := h.getChecksum(q, t)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
h.tablesChecksum[t] = checksum
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *postgreSQL) getChecksum(q queryable, tableName string) (string, error) {
|
|
sqlStr := fmt.Sprintf(`
|
|
SELECT md5(CAST((array_agg(t.*)) AS TEXT))
|
|
FROM %s AS t
|
|
`,
|
|
h.quoteKeyword(tableName),
|
|
)
|
|
|
|
var checksum sql.NullString
|
|
if err := q.QueryRow(sqlStr).Scan(&checksum); err != nil {
|
|
return "", err
|
|
}
|
|
return checksum.String, nil
|
|
}
|
|
|
|
func (*postgreSQL) quoteKeyword(s string) string {
|
|
parts := strings.Split(s, ".")
|
|
for i, p := range parts {
|
|
parts[i] = fmt.Sprintf(`"%s"`, p)
|
|
}
|
|
return strings.Join(parts, ".")
|
|
}
|