use std::{
    ops::{Add, Div},
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    },
};

use anyhow::anyhow;
use itertools::{Itertools, izip};
use polars::{
    frame::DataFrame,
    prelude::{AnyValue, ChunkAgg, DataType, NamedFrom, SeriesMethods},
    series::Series,
};
use ratatui::widgets::Cell;
use rayon::iter::{ParallelBridge, ParallelIterator};
use unicode_width::UnicodeWidthStr;

use crate::{AppResult, misc::jagged_vec::JaggedVec, tui::sheet::SheetSection};

use super::type_ext::HasSubsequence;

/// Rendering, measuring and matching helpers for a single polars [`AnyValue`].
pub trait AnyValueExt {
    /// Renders the value as the text shown in a table cell.
    fn into_single_line(self) -> String;

    /// Estimates how many terminal columns the value's cell text occupies,
    /// using `num_buffer` as scratch space for formatting numbers.
    fn width(self, num_buffer: &mut NumBuffer) -> usize;

    /// Renders the value for a detail view (binary data is expanded into a
    /// hex dump there, rather than summarised).
    fn into_multi_line(self) -> String;

    /// Converts the value into a ratatui [`Cell`]; `width` is the column
    /// width used to right-align floating-point values.
    fn into_cell(self, width: usize) -> Cell<'static>;

    /// Fuzzy-matches `other` against the value's text via `HasSubsequence`;
    /// a null value never matches.
    fn fuzzy_cmp(self, other: &str) -> bool;
}

impl AnyValueExt for AnyValue<'_> {
    fn into_single_line(self) -> String {
        match self {
            AnyValue::Null => "".to_owned(),
            AnyValue::StringOwned(v) => v.to_string(),
            AnyValue::String(v) => v.to_string(),
            AnyValue::Categorical(idx, rev_map) => {
                rev_map.cat_to_str(idx).unwrap_or_default().to_owned()
            }
            AnyValue::CategoricalOwned(idx, rev_map) => {
                rev_map.cat_to_str(idx).unwrap_or_default().to_owned()
            }
            AnyValue::Binary(buf) => format!("Blob (Length: {})", buf.len()),
            AnyValue::BinaryOwned(buf) => format!("Blob (Length: {})", buf.len()),
            _ => self.to_string(),
        }
    }

    fn width(self, num_buffer: &mut NumBuffer) -> usize {
        match self {
            AnyValue::Null => 0,
            AnyValue::Boolean(v) => {
                if v {
                    4 // true
                } else {
                    5 // false
                }
            }
            AnyValue::String(s) => s.lines().next().unwrap_or_default().width(),
            AnyValue::UInt8(u) => num_buffer.itoa.format(u).len(),
            AnyValue::UInt16(u) => num_buffer.itoa.format(u).len(),
            AnyValue::UInt32(u) => num_buffer.itoa.format(u).len(),
            AnyValue::UInt64(u) => num_buffer.itoa.format(u).len(),
            AnyValue::UInt128(u) => num_buffer.itoa.format(u).len(),
            AnyValue::Int8(i) => num_buffer.itoa.format(i).len(),
            AnyValue::Int16(i) => num_buffer.itoa.format(i).len(),
            AnyValue::Int32(i) => num_buffer.itoa.format(i).len(),
            AnyValue::Int64(i) => num_buffer.itoa.format(i).len(),
            AnyValue::Int128(i) => num_buffer.itoa.format(i).len(),
            AnyValue::Float32(f) => num_buffer.ryu.format(f).len(),
            AnyValue::Float64(f) => num_buffer.ryu.format(f).len(),
            AnyValue::Date(_) => 10, // 1970-10-10
            AnyValue::Datetime(_, _, _) | AnyValue::DatetimeOwned(_, _, _) => 19, // 2019-06-30 07:49:05
            _ => self.to_string().width(),
        }
    }

    fn into_multi_line(self) -> String {
        match self {
            AnyValue::Null => "".to_owned(),
            AnyValue::StringOwned(v) => v.to_string(),
            AnyValue::String(v) => v.to_string(),
            AnyValue::Categorical(idx, rev_map) => {
                rev_map.cat_to_str(idx).unwrap_or_default().to_owned()
            }
            AnyValue::CategoricalOwned(idx, rev_map) => {
                rev_map.cat_to_str(idx).unwrap_or_default().to_owned()
            }
            AnyValue::Binary(buf) => bytes_to_string(buf),
            AnyValue::BinaryOwned(buf) => bytes_to_string(buf),
            _ => self.to_string(),
        }
    }

    fn into_cell(self, width: usize) -> Cell<'static> {
        match self {
            AnyValue::Float32(f) => Cell::new(format!("{f:>w$.2}", w = width)),
            AnyValue::Float64(f) => Cell::new(format!("{f:>w$.2}", w = width)),
            _ => Cell::new(self.into_single_line()),
        }
    }

    fn fuzzy_cmp(self, other: &str) -> bool {
        match self {
            AnyValue::Null => false,
            AnyValue::StringOwned(pl_small_str) => pl_small_str.has_subsequence(other),
            AnyValue::String(val) => val.has_subsequence(other),
            _ => self.into_multi_line().has_subsequence(other),
        }
    }
}

/// Reusable scratch space for formatting numbers while measuring cell widths
/// in [`AnyValueExt::width`]: `itoa` for integers, `ryu` for floats.
#[derive(Clone, Default)]
pub struct NumBuffer {
    itoa: itoa::Buffer,
    ryu: ryu::Buffer,
}

/// Table-level helpers used to lay out, inspect and plot a [`DataFrame`].
pub trait DataFrameExt {
    /// Display width of each column, in column order.
    fn widths(&self) -> Vec<usize>;

    /// One [`SheetSection`] per column describing the row at `pos`.
    fn get_sheet_sections(&self, pos: usize) -> Vec<SheetSection>;

    /// `(x, y)` points built from columns `x_label` and `y_label`, both cast
    /// to `f64`.
    fn scatter_plot_data(&self, x_label: &str, y_label: &str) -> AppResult<JaggedVec<(f64, f64)>>;

    /// Scatter points split by the values of column `group_by`; returns the
    /// per-group point lists together with the matching group names.
    // The tuple return type is spelled out on purpose so implementors and
    // callers see the exact shape; the lint is silenced instead.
    #[allow(clippy::type_complexity)]
    fn scatter_plot_data_grouped(
        &self,
        x_label: &str,
        y_label: &str,
        group_by: &str,
    ) -> AppResult<(JaggedVec<(f64, f64)>, Vec<String>)>;

    /// Histogram of column `col_name` as `(label, count)` pairs; `buckets`
    /// controls how numeric data is grouped into ranges.
    fn histogram_plot_data(&self, col_name: &str, buckets: usize)
    -> AppResult<Vec<(String, u64)>>;
}

/// Fallible element-wise mapping over a [`Series`].
pub trait TryMapAll {
    /// Applies `cast` to every value and collects the results into a new
    /// series, or returns `None` if `cast` rejects any value.
    fn try_map_all(
        &self,
        cast: impl Fn(AnyValue) -> Option<AnyValue<'static>> + Sync + Send + 'static,
    ) -> Option<Series>;
}

/// Formats `buf` as a hex dump for the detail view.
///
/// The first line reads `Blob (Length: N)`. Each following line shows up to
/// 16 bytes as upper-case hex pairs, split into two groups of 8 separated by
/// three spaces. Each line is prefixed by its zero-padded row number. The
/// row-number width is the digit count of `len / 16`, rounded up to an even
/// number. An empty buffer produces only the header line plus a trailing
/// newline.
fn bytes_to_string(buf: impl AsRef<[u8]>) -> String {
    let bytes = buf.as_ref();

    let digits = bytes.len().div(16).to_string().len();
    let row_label_width = digits + digits % 2;

    let rows: Vec<String> = bytes
        .chunks(16)
        .enumerate()
        .map(|(row, line)| {
            let groups: Vec<String> = line
                .chunks(8)
                .map(|group| {
                    group
                        .iter()
                        .map(|byte| format!("{byte:02X}"))
                        .collect::<Vec<_>>()
                        .join(" ")
                })
                .collect();
            format!("{row:0row_label_width$}:  {}", groups.join("   "))
        })
        .collect();

    format!("Blob (Length: {})\n{}", bytes.len(), rows.join("\n"))
}

impl DataFrameExt for DataFrame {
    /// Returns the display width of every column, in column order (see
    /// `series_width`).
    fn widths(&self) -> Vec<usize> {
        self.iter().map(series_width).collect()
    }

    /// Builds one [`SheetSection`] per column for the row at `pos`.
    ///
    /// Each section is titled `"{name} ({dtype})"` and holds the value's
    /// multi-line rendering. An out-of-range `pos` yields an empty list.
    fn get_sheet_sections(&self, pos: usize) -> Vec<SheetSection> {
        izip!(
            self.get_column_names().into_iter(),
            self.get(pos)
                .unwrap_or_default()
                .into_iter()
                .map(AnyValueExt::into_multi_line),
            self.dtypes()
        )
        .map(|(header, content, dtype)| SheetSection::new(format!("{header} ({dtype})"), content))
        .collect_vec()
    }

    /// Collects `(x, y)` points from columns `x_label` and `y_label`.
    ///
    /// Both columns are cast to `Float64`; rows where either value is null
    /// are skipped.
    fn scatter_plot_data(&self, x_label: &str, y_label: &str) -> AppResult<JaggedVec<(f64, f64)>> {
        Ok(self
            .column(x_label)?
            .cast(&DataType::Float64)?
            .f64()?
            .iter()
            .zip(
                self.column(y_label)?
                    .cast(&DataType::Float64)?
                    .f64()?
                    .iter(),
            )
            .filter_map(|(x, y)| Some((x?, y?)))
            .collect())
    }

    /// Splits the frame by the values of `group_by` and collects scatter
    /// points for each group.
    ///
    /// Groups are ordered by their rendered name. The name falls back to
    /// `"null"` only if the group column cannot be read.
    fn scatter_plot_data_grouped(
        &self,
        x_label: &str,
        y_label: &str,
        group_by: &str,
    ) -> AppResult<(JaggedVec<(f64, f64)>, Vec<String>)> {
        let mut groups = Vec::new();
        let mut data = JaggedVec::new();
        for (name, df) in self
            .partition_by(vec![group_by], true)?
            .into_iter()
            .map(|df| {
                let name = df
                    .column(group_by)
                    .and_then(|column| column.get(0))
                    .map(AnyValueExt::into_single_line)
                    .unwrap_or_else(|_| "null".to_owned());
                (name, df)
            })
            .sorted_by(|(a, _), (b, _)| a.cmp(b))
        {
            groups.push(name);
            data.push(df.scatter_plot_data(x_label, y_label)?);
        }
        Ok((data, groups))
    }

    /// Builds a histogram of column `col_name` as `(label, count)` pairs.
    ///
    /// Integer columns with at most `buckets` distinct values, as well as
    /// boolean and string columns, get one entry per distinct value. Other
    /// integer columns and float/decimal columns are split into `buckets`
    /// equal-width ranges.
    ///
    /// # Errors
    ///
    /// Fails for unsupported column types, or if the column does not exist.
    fn histogram_plot_data(&self, col_name: &str, buckets: usize) -> AppResult<Vec<(String, u64)>> {
        let col = self.column(col_name)?;
        // polars refuses `value_counts` when the count column would share the
        // source column's name, so a fixed name such as "value" breaks for a
        // column called "value". `{col_name}_count` can never collide.
        let value_counts = || {
            col.as_materialized_series().value_counts(
                true,
                true,
                format!("{col_name}_count").into(),
                false,
            )
        };
        match col.dtype() {
            DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
            | DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::Int128 => {
                let counts = value_counts()?;
                if counts.height() <= buckets {
                    discrete_histogram(counts)
                } else {
                    continues_histogram(counts, buckets)
                }
            }
            DataType::Float32 | DataType::Float64 | DataType::Decimal(_, _) => {
                continues_histogram(value_counts()?, buckets)
            }
            DataType::Boolean | DataType::String => discrete_histogram(value_counts()?),
            _ => Err(anyhow!("Unsupported column type"))?,
        }
    }
}

/// Computes the display width of a column: the wider of its name and its
/// widest value as measured by [`AnyValueExt::width`].
///
/// Values are measured in parallel via rayon; each worker clones its own
/// [`NumBuffer`] to use as number-formatting scratch space.
fn series_width(series: &Series) -> usize {
    let widest_value = series
        .iter()
        .par_bridge()
        .map_with(NumBuffer::default(), |scratch, value| value.width(scratch))
        .max()
        .unwrap_or_default();
    series.name().width().max(widest_value)
}

impl TryMapAll for Series {
    /// Applies `cast` to every value on scoped worker threads and collects the
    /// results into a new series with the same name.
    ///
    /// The series is split into at most `num_cpus::get()` contiguous pieces,
    /// one thread per piece. If `cast` returns `None` for any value, a shared
    /// flag tells every worker to stop, and the method returns `None`.
    fn try_map_all(
        &self,
        cast: impl Fn(AnyValue) -> Option<AnyValue<'static>> + Sync + Send + 'static,
    ) -> Option<Series> {
        let break_out = Arc::new(AtomicBool::new(false));
        let mut new = vec![AnyValue::Null; self.len()];
        // Ceiling division keeps the number of pieces (and threads) at or below
        // the CPU count. Floor division spawned an extra thread for the
        // remainder, and one thread per element when cpus < len < 2 * cpus.
        // `max(1)` keeps `chunks_mut` valid for an empty series.
        let piece_len = self.len().div_ceil(num_cpus::get()).max(1);
        std::thread::scope(|scope| {
            for (idx, new_chunk) in new.chunks_mut(piece_len).enumerate() {
                let offset = (idx * piece_len) as i64;
                let break_out = Arc::clone(&break_out);
                let cast = &cast;
                scope.spawn(move || {
                    // The last piece may be shorter; `slice` clamps the length.
                    let series = self.slice(offset, piece_len);
                    for (new_val, val) in new_chunk.iter_mut().zip(series.iter()) {
                        // Stop as soon as any worker has reported a failure.
                        if break_out.load(Ordering::Relaxed) {
                            break;
                        }
                        match cast(val) {
                            Some(parsed) => *new_val = parsed,
                            None => {
                                break_out.store(true, Ordering::Relaxed);
                                break;
                            }
                        }
                    }
                });
            }
        });
        // `thread::scope` joins every worker before returning, so this load
        // observes all stores. `then` (unlike `then_some`) skips building the
        // Series when a value could not be mapped.
        (!break_out.load(Ordering::Relaxed)).then(|| Series::new(self.name().to_owned(), new))
    }
}

/// Turns a `value_counts` result into one `(label, count)` entry per row.
///
/// Column 0 supplies the labels (rendered with
/// [`AnyValueExt::into_single_line`]); column 1 must be `u32` and supplies the
/// counts, with a null count read as 0. Rows keep the frame's order.
fn discrete_histogram(mut counts: DataFrame) -> AppResult<Vec<(String, u64)>> {
    counts.rechunk_mut();
    let labels = counts[0]
        .as_materialized_series()
        .iter()
        .map(AnyValue::into_single_line);
    let tallies = counts[1]
        .as_materialized_series()
        .u32()?
        .iter()
        .map(|count| u64::from(count.unwrap_or_default()));
    Ok(labels.zip(tallies).collect())
}

/// Buckets a `value_counts` result into `buckets` equal-width ranges spanning
/// the observed minimum to maximum.
///
/// `counts` is expected to hold the distinct values in column 0 (castable to
/// `Float64`) and their `u32` counts in column 1. Each output entry pairs a
/// `" start - end"` label (two decimals, right-aligned to a shared width) with
/// the total count of values in that range. Values equal to the maximum land
/// in the last bucket. A row whose value or count is null is ignored.
/// Requesting zero buckets yields an empty histogram.
///
/// # Errors
///
/// Fails if column 0 cannot be cast to `Float64`, if there is no non-null
/// value, or if column 1 is not `u32`.
fn continues_histogram(counts: DataFrame, buckets: usize) -> AppResult<Vec<(String, u64)>> {
    let casted = counts[0].cast(&DataType::Float64)?;
    let arr = casted.f64()?;
    let (min, max) = arr.min_max().ok_or_else(|| anyhow!("No value found"))?;
    let tallies = counts[1].as_materialized_series().u32()?;
    // Zero buckets: nothing to fill. Indexing into an empty vector would panic.
    let Some(last) = buckets.checked_sub(1) else {
        return Ok(Vec::new());
    };
    let width = (max - min) / (buckets as f64);
    let totals = arr
        .iter()
        .zip(tallies.iter())
        // Zip first, then drop nulls, so each value stays paired with its own
        // count. Flattening both sides separately would shift every later pair
        // whenever column 0 contains a null.
        .filter_map(|(value, count)| Some((value?, count?)))
        .fold(vec![0_u64; buckets], |mut totals, (value, count)| {
            // `as usize` saturates, and NaN (zero width when all values are
            // equal) becomes 0. `min(last)` places the maximum in the last bucket.
            let idx = (((value - min) / width) as usize).min(last);
            totals[idx] += u64::from(count);
            totals
        });
    // Size labels from both ends so a wide negative minimum stays aligned too.
    let label_len = format!("{min:.2}").len().max(format!("{max:.2}").len());
    Ok(totals
        .into_iter()
        .enumerate()
        .map(|(idx, total)| {
            let start = (idx as f64) * width + min;
            let end = (idx.add(1) as f64) * width + min;
            (
                format!(" {start:>w$.2} - {end:>w$.2}", w = label_len),
                total,
            )
        })
        .collect())
}
