// Copyright 2019 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package merge

import (
	"context"
	"errors"
	"fmt"

	"github.com/dolthub/go-mysql-server/sql"
	goerrors "gopkg.in/src-d/go-errors.v1"

	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
	"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
	"github.com/dolthub/dolt/go/store/hash"
	"github.com/dolthub/dolt/go/store/types"
)

var (
	// ErrFastForward is a sentinel error with the message "fast forward".
	// NOTE(review): presumably signals that a merge resolves as a fast-forward — confirm against callers.
	ErrFastForward = errors.New("fast forward")

	// ErrTableDeletedAndModified reports a table that was deleted on one side of
	// the merge and modified on the other.
	ErrTableDeletedAndModified = errors.New("conflict: table with same name deleted and modified ")

	// ErrTableDeletedAndSchemaModified reports a table that was deleted on one
	// side of the merge while its schema was modified on the other.
	ErrTableDeletedAndSchemaModified = errors.New("conflict: table with same name deleted and its schema modified ")

	// ErrSchemaConflict is an error kind whose single %s argument describes the
	// schema conflict that aborted the merge.
	ErrSchemaConflict = goerrors.NewKind("schema conflict found, merge aborted. Please alter schema to prevent schema conflicts before merging: %s")

	// ErrCantOverwriteConflicts is returned when there are unresolved conflicts
	// and the merge produces new conflicts. Because we currently don't have a model
	// to merge sets of conflicts together, we need to abort the merge at this
	// point.
	ErrCantOverwriteConflicts = errors.New("existing unresolved conflicts would be" +
		" overridden by new conflicts produced by merge. Please resolve them and try again")

	// ErrConflictsIncompatible reports that existing conflicts have a different
	// schema than the conflicts this merge would generate.
	ErrConflictsIncompatible = errors.New("the existing conflicts are of a different schema" +
		" than the conflicts generated by this merge. Please resolve them and try again")

	// ErrMultipleViolationsForRow reports that more than one constraint violation
	// was found for a single row, which is not supported.
	ErrMultipleViolationsForRow = errors.New("multiple violations for row not supported")

	// ErrSameTblAddedTwice is an error kind whose %s argument is the name of a
	// table that was added under the same name in both commits being merged.
	ErrSameTblAddedTwice = goerrors.NewKind("table with same name '%s' added in 2 commits can't be merged")
)

// MergeCommits three-way merges the root values of |commit| (ours) and
// |mergeCommit| (theirs), using the commit returned by
// doltdb.GetCommitAncestor as the merge base. Schema conflicts are kept in the
// returned Result instead of aborting the merge, and the merge is not treated
// as a cherry-pick.
func MergeCommits(ctx *sql.Context, tableResolver doltdb.TableResolver, commit, mergeCommit *doltdb.Commit, opts editor.Options) (*Result, error) {
	ancestorOpt, err := doltdb.GetCommitAncestor(ctx, commit, mergeCommit)
	if err != nil {
		return nil, err
	}

	base, resolved := ancestorOpt.ToCommit()
	if !resolved {
		// The ancestor is expected to have been resolved before reaching this
		// point; an unresolved (ghost) ancestor here is a runtime failure.
		return nil, doltdb.ErrGhostCommitRuntimeFailure
	}

	oursRoot, err := commit.GetRootValue(ctx)
	if err != nil {
		return nil, err
	}
	theirsRoot, err := mergeCommit.GetRootValue(ctx)
	if err != nil {
		return nil, err
	}
	baseRoot, err := base.GetRootValue(ctx)
	if err != nil {
		return nil, err
	}

	return MergeRoots(ctx, tableResolver, oursRoot, theirsRoot, baseRoot, mergeCommit, base, opts, MergeOpts{
		IsCherryPick:        false,
		KeepSchemaConflicts: true,
	})
}

// Result is the outcome of a merge as returned by MergeRoots and MergeCommits.
//
// Root is the merged root value; conflicts and constraint violations produced
// by the merge are stored in it. SchemaConflicts lists the schema conflicts
// that were collected instead of being returned as errors. Stats holds
// per-table merge statistics keyed by table name.
type Result struct {
	Root            doltdb.RootValue
	SchemaConflicts []SchemaConflict
	Stats           map[doltdb.TableName]*MergeStats
}

// HasSchemaConflicts reports whether the merge recorded at least one schema
// conflict.
func (r Result) HasSchemaConflicts() bool {
	return len(r.SchemaConflicts) != 0
}

// HasMergeArtifacts reports whether the merge left anything behind for the
// user to deal with: either a schema conflict, or a table whose stats report
// artifacts.
func (r Result) HasMergeArtifacts() bool {
	found := r.HasSchemaConflicts()
	for _, tblStats := range r.Stats {
		if found {
			break
		}
		found = tblStats.HasArtifacts()
	}
	return found
}

// CountOfTablesWithDataConflicts returns how many tables in this merge result
// report either data conflicts or root-object conflicts.
func (r Result) CountOfTablesWithDataConflicts() int {
	var n int
	for _, s := range r.Stats {
		if !s.HasDataConflicts() && !s.HasRootObjectConflicts() {
			continue
		}
		n++
	}
	return n
}

// CountOfTablesWithSchemaConflicts returns how many tables in this merge
// result report a schema conflict in their per-table stats.
func (r Result) CountOfTablesWithSchemaConflicts() int {
	var n int
	for _, s := range r.Stats {
		if !s.HasSchemaConflicts() {
			continue
		}
		n++
	}
	return n
}

// CountOfTablesWithConstraintViolations returns how many tables in this merge
// result report at least one constraint violation.
func (r Result) CountOfTablesWithConstraintViolations() int {
	var n int
	for _, s := range r.Stats {
		if !s.HasConstraintViolations() {
			continue
		}
		n++
	}
	return n
}

// SchemaConflictTableNames returns the table name of each schema conflict in
// |sc|, in the same order. The result is always non-nil, even for empty input.
func SchemaConflictTableNames(sc []SchemaConflict) []doltdb.TableName {
	names := make([]doltdb.TableName, 0, len(sc))
	for i := range sc {
		names = append(names, sc[i].TableName)
	}
	return names
}

// MergeRoots three-way merges |ourRoot|, |theirRoot|, and |ancRoot| and returns
// the merged root. If any conflicts or constraint violations are produced they
// are stored in the merged root. If |ourRoot| already contains conflicts they
// are stashed before the merge is performed. We abort the merge if the stash
// contains conflicts and we produce new conflicts. We currently don't have a
// model to merge conflicts together.
//
// Constraint violations that exist in ancestor are stashed and merged with the
// violations we detect when we diff the ancestor and the newly merged root.
// Both kinds of stashing only happen for storage formats other than DOLT.
//
// Database collations are merged first. A collation change made on only one
// side is kept. If both sides changed it to different collations, an error is
// returned.
//
// Schema conflicts are collected into Result.SchemaConflicts for the DOLT
// format and returned as an error otherwise. Delete/modify table conflicts are
// collected when |mergeOpts|.KeepSchemaConflicts is set and returned as an
// error otherwise. Foreign key conflicts always fail the merge.
//
// |theirs| is their working set or commit. Its hash is used to key any
// artifacts generated by this merge. |ancestor| is similar and is used to
// retrieve the base value for a conflict.
func MergeRoots(
	ctx *sql.Context,
	tableResolver doltdb.TableResolver,
	ourRoot, theirRoot, ancRoot doltdb.RootValue,
	theirs, ancestor doltdb.Rootish,
	opts editor.Options,
	mergeOpts MergeOpts,
) (*Result, error) {
	var (
		conflictStash  *conflictStash
		violationStash *violationStash
		err            error
	)

	// Compute the storage format once; it selects both the stashing strategy
	// below and how results are finalized at the end of this function.
	isDoltFormat := types.IsFormat_DOLT(ourRoot.VRW().Format())
	if !isDoltFormat {
		ourRoot, conflictStash, err = stashConflicts(ctx, ourRoot)
		if err != nil {
			return nil, err
		}
		ancRoot, violationStash, err = stashViolations(ctx, ancRoot)
		if err != nil {
			return nil, err
		}
	}

	// merge collations
	oColl, err := ourRoot.GetCollation(ctx)
	if err != nil {
		return nil, err
	}
	tColl, err := theirRoot.GetCollation(ctx)
	if err != nil {
		return nil, err
	}
	aColl, err := ancRoot.GetCollation(ctx)
	if err != nil {
		return nil, err
	}
	mergedRoot := ourRoot

	if oColl != tColl {
		// Both sides changed the collation, to different values: conflict.
		if oColl != aColl && tColl != aColl {
			oCollName := sql.CollationID(oColl).Collation().Name
			tCollName := sql.CollationID(tColl).Collation().Name
			return nil, fmt.Errorf("database collation conflict, please resolve manually. ours: %s, theirs: %s", oCollName, tCollName)
		}
		// Only their side changed it: take theirs.
		if oColl == aColl {
			mergedRoot, err = mergedRoot.SetCollation(ctx, tColl)
			if err != nil {
				return nil, err
			}
		}
		// Otherwise only our side changed it, and mergedRoot already carries ours.
	}

	// Make sure to pass in ourRoot as the first RootValue so that ourRoot's table names will be merged first.
	// This helps to avoid non-deterministic error result for table rename cases. Renaming a table creates two changes:
	// 1. dropping the old name table
	// 2. adding the new name table
	// Dropping the old name table will trigger delete/modify conflict, which is the preferred error case over
	// same column tag used error returned from creating the new name table.
	tblNames, err := doltdb.UnionTableNames(ctx, ourRoot, theirRoot)
	if err != nil {
		return nil, err
	}

	tblToStats := make(map[doltdb.TableName]*MergeStats)

	// Merge tables one at a time. This is done based on name. With table names from ourRoot being merged first,
	// renaming a table will return delete/modify conflict error consistently.
	// TODO: merge based on a more durable table identity that persists across renames
	merger, err := NewMerger(ourRoot, theirRoot, ancRoot, theirs, ancestor, ourRoot.VRW(), ourRoot.NodeStore())
	if err != nil {
		return nil, err
	}

	destSchemaNames, err := getDatabaseSchemaNames(ctx, ourRoot)
	if err != nil {
		return nil, err
	}

	// visitedTables holds all tables that were added, removed, or modified (basically not "unmodified")
	visitedTables := make(map[string]struct{})
	var schConflicts []SchemaConflict
	for _, tblName := range tblNames {
		mergedTable, stats, err := merger.MergeTable(ctx, tblName, opts, mergeOpts)

		// errors.Is takes (err, target); with the arguments reversed a wrapped
		// sentinel would not be recognized and would escape as a hard error.
		if errors.Is(err, ErrTableDeletedAndModified) && doltdb.IsFullTextTable(tblName.Name) {
			// If a Full-Text table was both modified and deleted, then we want to ignore the deletion.
			// If there's a true conflict, then the parent table will catch the conflict.
			stats = &MergeStats{Operation: TableModified}
		} else if errors.Is(err, ErrTableDeletedAndSchemaModified) || errors.Is(err, ErrTableDeletedAndModified) {
			tblToStats[tblName] = &MergeStats{
				Operation:       TableModified,
				SchemaConflicts: 1,
			}
			conflict := SchemaConflict{
				TableName:            tblName,
				ModifyDeleteConflict: true,
			}
			if !mergeOpts.KeepSchemaConflicts {
				return nil, conflict
			}
			schConflicts = append(schConflicts, conflict)
			continue
		} else if err != nil {
			return nil, err
		}
		// If this table was visited during the merge, then we'll add it to the set
		if stats.Operation != TableUnmodified {
			visitedTables[tblName.Name] = struct{}{}
		}
		if doltdb.IsFullTextTable(tblName.Name) && (stats.Operation == TableModified || stats.Operation == TableRemoved) {
			// We handle removal and modification later in the rebuilding process, so we'll skip those.
			// We do not handle adding new tables, so we allow that to proceed.
			continue
		}
		if mergedTable.conflict.Count() > 0 {
			if isDoltFormat {
				schConflicts = append(schConflicts, mergedTable.conflict)
			} else {
				// The old format has no way to store schema conflicts, so surface it as an error.
				return nil, mergedTable.conflict
			}
		}

		if mergedTable.table != nil {
			tblToStats[tblName] = stats

			// edge case: if we're merging a table with a schema name to a root that doesn't have that schema,
			// we implicitly create that schema on the destination root in addition to updating the list of schemas
			if tblName.Schema != "" && !destSchemaNames.Contains(tblName.Schema) {
				mergedRoot, err = mergedRoot.CreateDatabaseSchema(ctx, schema.DatabaseSchema{
					Name: tblName.Schema,
				})
				if err != nil {
					return nil, err
				}
				destSchemaNames.Add(tblName.Schema)
			}

			mergedRoot, err = mergedRoot.PutTable(ctx, tblName, mergedTable.table)
			var errTagPreviouslyUsed schema.ErrTagPrevUsed
			if errors.As(err, &errTagPreviouslyUsed) {
				return nil, fmt.Errorf("cannot merge, column %s on table %s has duplicate tag as table %s. This was likely because one of the tables is a rename of the other",
					errTagPreviouslyUsed.NewColName, errTagPreviouslyUsed.NewTableName, errTagPreviouslyUsed.OldTableName)
			}
			if err != nil {
				return nil, err
			}
			continue
		} else if mergedTable.rootObj != nil {
			tblToStats[tblName] = stats
			if stats.Operation != TableUnmodified || stats.RootObjectConflicts > 0 {
				mergedRoot, err = mergedRoot.PutRootObject(ctx, tblName, mergedTable.rootObj)
				if err != nil {
					return nil, err
				}
			}
			continue
		}

		// The merger produced neither a table nor a root object, so the merge
		// result has nothing under this name.
		mergedRootHasTable, err := mergedRoot.HasTable(ctx, tblName)
		if err != nil {
			return nil, err
		}

		if mergedRootHasTable {
			// Our root still has the table but the merge result does not (the
			// other side deleted it), so remove it from the merged root.
			tblToStats[tblName] = &MergeStats{Operation: TableRemoved}

			// TODO: drop schemas as necessary
			mergedRoot, err = mergedRoot.RemoveTables(ctx, false, true, tblName)
			if err != nil {
				return nil, err
			}
		} else {
			// Our root no longer has the table, so there is nothing to remove.
			// The merger must have reported the table as removed; anything else
			// means the merge state is inconsistent.
			if stats.Operation != TableRemoved {
				panic(fmt.Sprintf("Invalid merge state for table %s. This is a bug.", tblName))
			}
		}
	}

	mergedRoot, err = rebuildFullTextIndexes(ctx, mergedRoot, ourRoot, theirRoot, visitedTables)
	if err != nil {
		return nil, err
	}

	mergedFKColl, conflicts, err := ForeignKeysMerge(ctx, tableResolver, mergedRoot, ourRoot, theirRoot, ancRoot)
	if err != nil {
		return nil, err
	}
	if len(conflicts) > 0 {
		return nil, errors.New("foreign key conflicts")
	}

	mergedRoot, err = mergedRoot.PutForeignKeyCollection(ctx, mergedFKColl)
	if err != nil {
		return nil, err
	}

	h, err := merger.rightSrc.HashOf()
	if err != nil {
		return nil, err
	}

	// A nil table set is passed through when the caller did not restrict which
	// tables get foreign key violations recorded.
	var tableSet *doltdb.TableNameSet
	if mergeOpts.RecordViolationsForTables != nil {
		tableSet = doltdb.NewCaseInsensitiveTableNameSet(nil)
		for tableName := range mergeOpts.RecordViolationsForTables {
			tableSet.Add(tableName)
		}
	}

	mergedRoot, _, err = AddForeignKeyViolations(ctx, tableResolver, mergedRoot, ancRoot, tableSet, h)
	if err != nil {
		return nil, err
	}

	if isDoltFormat {
		err = getConstraintViolationStats(ctx, mergedRoot, tblToStats)
		if err != nil {
			return nil, err
		}

		return &Result{
			Root:            mergedRoot,
			SchemaConflicts: schConflicts,
			Stats:           tblToStats,
		}, nil
	}

	// Old format: reconcile the stashed violations and conflicts with the merge output.
	mergedRoot, err = mergeCVsWithStash(ctx, mergedRoot, violationStash)
	if err != nil {
		return nil, err
	}

	err = getConstraintViolationStats(ctx, mergedRoot, tblToStats)
	if err != nil {
		return nil, err
	}

	mergedHasTableConflicts := checkForTableConflicts(tblToStats)
	if !conflictStash.Empty() && mergedHasTableConflicts {
		return nil, ErrCantOverwriteConflicts
	} else if !conflictStash.Empty() {
		mergedRoot, err = applyConflictStash(ctx, conflictStash.Stash, mergedRoot)
		if err != nil {
			return nil, err
		}
	}

	return &Result{
		Root:            mergedRoot,
		SchemaConflicts: schConflicts,
		Stats:           tblToStats,
	}, nil
}

// mergeCVsWithStash merges the table constraint violations in |stash| with |root|.
// Returns an updated root with all the merged CVs.
//
// Stashed violations for tables that no longer exist in |root| are dropped,
// since the table holding them was deleted. A nil |stash| leaves |root|
// unchanged. If the same violation key is present in both the table and the
// stash with different values, this panics: the two sets are expected to agree.
func mergeCVsWithStash(ctx context.Context, root doltdb.RootValue, stash *violationStash) (doltdb.RootValue, error) {
	if stash == nil {
		return root, nil
	}

	// Every table update must be applied to updatedRoot, not root; otherwise
	// each iteration would discard the updates made by earlier iterations.
	updatedRoot := root
	for name, stashed := range stash.Stash {
		tblName := doltdb.TableName{Name: name}
		tbl, ok, err := updatedRoot.GetTable(ctx, tblName)
		if err != nil {
			return nil, err
		}
		if !ok {
			// the table with the CVs was deleted
			continue
		}
		curr, err := tbl.GetConstraintViolations(ctx)
		if err != nil {
			return nil, err
		}
		unioned, err := types.UnionMaps(ctx, curr, stashed, func(key types.Value, currV types.Value, stashV types.Value) (types.Value, error) {
			if !currV.Equals(stashV) {
				panic(fmt.Sprintf("encountered conflict when merging constraint violations, conflicted key: %v\ncurrent value: %v\nstashed value: %v\n", key, currV, stashV))
			}
			return currV, nil
		})
		if err != nil {
			return nil, err
		}
		tbl, err = tbl.SetConstraintViolations(ctx, unioned)
		if err != nil {
			return nil, err
		}
		updatedRoot, err = updatedRoot.PutTable(ctx, tblName, tbl)
		if err != nil {
			return nil, err
		}
	}
	return updatedRoot, nil
}

// checkForTableConflicts reports whether any table's stats report conflicts
// without also reporting root-object conflicts.
func checkForTableConflicts(tblToStats map[doltdb.TableName]*MergeStats) bool {
	for _, s := range tblToStats {
		if !s.HasConflicts() {
			continue
		}
		if !s.HasRootObjectConflicts() {
			return true
		}
	}
	return false
}

// getConstraintViolationStats sets ConstraintViolations on each entry of
// |tblToStats| to the violation count of the same-named table in |root|.
// Entries whose table is absent from |root| are left untouched.
func getConstraintViolationStats(ctx context.Context, root doltdb.RootValue, tblToStats map[doltdb.TableName]*MergeStats) error {
	for name, s := range tblToStats {
		tbl, exists, err := root.GetTable(ctx, name)
		if err != nil {
			return err
		}
		if !exists {
			continue
		}
		count, err := tbl.NumConstraintViolations(ctx)
		if err != nil {
			return err
		}
		s.ConstraintViolations = int(count)
	}
	return nil
}

// ArtifactStatus groups table names by the kind of merge artifact they carry:
// schema conflicts, data conflicts, and constraint violations.
// NOTE(review): nothing visible here populates it; presumably filled in
// elsewhere in the package — confirm against callers.
type ArtifactStatus struct {
	SchemaConflictsTables      []string
	DataConflictTables         []string
	ConstraintViolationsTables []string
}

// HasConflicts reports whether any table has a schema conflict or a data
// conflict.
func (as ArtifactStatus) HasConflicts() bool {
	return len(as.SchemaConflictsTables) != 0 || len(as.DataConflictTables) != 0
}

// HasConstraintViolations reports whether any table has a constraint
// violation.
func (as ArtifactStatus) HasConstraintViolations() bool {
	return len(as.ConstraintViolationsTables) != 0
}

// MergeWouldStompChanges finds tables whose uncommitted changes would be
// overwritten by merging |mergeCommit|. A table is "stomped" when it differs
// between roots.Head and roots.Working AND between roots.Head and the merge
// commit's root. The stomped names are returned in no particular order,
// together with the full head-to-working diff map.
func MergeWouldStompChanges(ctx context.Context, roots doltdb.Roots, mergeCommit *doltdb.Commit) ([]doltdb.TableName, map[doltdb.TableName]hash.Hash, error) {
	mergeRoot, err := mergeCommit.GetRootValue(ctx)
	if err != nil {
		return nil, nil, err
	}

	headHashes, err := doltdb.MapTableHashes(ctx, roots.Head)
	if err != nil {
		return nil, nil, err
	}
	workingHashes, err := doltdb.MapTableHashes(ctx, roots.Working)
	if err != nil {
		return nil, nil, err
	}
	incomingHashes, err := doltdb.MapTableHashes(ctx, mergeRoot)
	if err != nil {
		return nil, nil, err
	}

	workingDiffs := diffTableHashes(headHashes, workingHashes)
	incomingDiffs := diffTableHashes(headHashes, incomingHashes)

	stomped := make([]doltdb.TableName, 0, len(workingDiffs))
	for name := range workingDiffs {
		// Disallow even when the working change matches the merge change
		// exactly; this mirrors git's behavior.
		if _, touchedByMerge := incomingDiffs[name]; !touchedByMerge {
			continue
		}
		stomped = append(stomped, name)
	}

	return stomped, workingDiffs, nil
}

// diffTableHashes returns every table name whose hash differs between
// |headTableHashes| and |otherTableHashes|. Tables that were added or modified
// map to their hash in |otherTableHashes|. Tables that were deleted (present
// only in head) map to the zero hash.
func diffTableHashes(headTableHashes, otherTableHashes map[doltdb.TableName]hash.Hash) map[doltdb.TableName]hash.Hash {
	changed := make(map[doltdb.TableName]hash.Hash)

	// Additions and modifications: names in other that are missing from head
	// or carry a different hash there.
	for name, otherHash := range otherTableHashes {
		if headHash, inHead := headTableHashes[name]; !inHead || headHash != otherHash {
			changed[name] = otherHash
		}
	}

	// Deletions: names in head that other no longer has.
	for name := range headTableHashes {
		if _, inOther := otherTableHashes[name]; !inOther {
			changed[name] = hash.Hash{}
		}
	}

	return changed
}
