{-# LANGUAGE RecordWildCards, TupleSections #-}
module Main (main) where

import qualified Paths_ghc_lib as GHC

import GHC.Driver.Monad
import GHC
    ( runGhc, defaultErrorHandler
    , setTargets
    , parseModule, typecheckModule
    , TypecheckedModule(..)
    )

import GHC.Driver.Session
import GHC.Unit.Module.Name
import GHC.Unit.State
import GHC.Unit.Types
import GHC.Utils.Error
import GHC.Data.Bag
import GHC.Driver.Make
import GHC.Driver.Phases
import GHC.Unit.Module.Env
import GHC.Driver.Types
import GHC.Hs.Instances ()
import GHC.Plugins
import GHC.Tc.Types
import qualified Data.Map as M
import Data.Time
import GHC.Iface.Make
import GHC.Utils.Fingerprint

import GHC.Data.IOEnv
import GHC.Tc.Utils.Monad (updateEps_)
import GHC.Core.InstEnv
import GHC.Core.FamInstEnv


import Control.Monad (unless)
import Data.Either (partitionEithers)
import Data.Maybe (fromMaybe)
import System.Environment (lookupEnv)
import System.Exit
import System.FilePath

import qualified GHC.Data.EnumSet as EnumSet
import qualified GHC.LanguageExtensions.Type as Ext

-- | Entry point: set up a GHC session rooted at ghc-lib's data directory,
-- then prepare and load the modules @Class@, @Instance@ and @Use@ one after
-- another, in that order. Source errors are printed and abort the program.
main :: IO ()
main = defaultErrorHandler defaultFatalMessager defaultFlushOut $ do
    libDir <- GHC.getDataDir
    runGhc (Just libDir) . withSourceErrors $ do
        setup
        mapM_ (\name -> prepareModule name >>= loadModule)
            ["Class", "Instance", "Use"]

-- | Run a GHC action. If it throws a 'SourceError', print the error and
-- terminate the process with exit status 1.
withSourceErrors :: (GhcMonad m) => m a -> m a
withSourceErrors action = handleSourceError onSourceError action
  where
    onSourceError err = do
        printException err
        liftIO (exitWith (ExitFailure 1))

-- | Configure the session's 'DynFlags' for typechecking only.
--
-- Sets 'Opt_NoTypeableBinds', @hscTarget = HscNothing@ and 'noMainModule'
-- as the main module. It restricts the package databases to a single one,
-- initialises the unit state, stores the flags back in the session and then
-- runs 'invalidateModSummaryCache'.
--
-- The package database is taken from the @GHC_LIB_PACKAGE_DB@ environment
-- variable when it is set. Otherwise the original hard-coded path is used,
-- which is specific to one developer's machine.
setup :: (GhcMonad m) => m ()
setup = do
    dflags0 <- getSessionDynFlags
    primPkgDb <- liftIO $ fromMaybe defaultPrimPkgDb <$> lookupEnv primPkgDbVar
    let dflags1 = (gopt_set dflags0 Opt_NoTypeableBinds)
            { hscTarget = HscNothing
            , mainModIs = noMainModule
              -- NOTE: GHC documents 'packageDBFlags' as being stored in
              -- reverse command-line order, so this corresponds to
              -- "-clear-package-db -package-db <primPkgDb>".
            , packageDBFlags = [PackageDB $ PkgDbPath primPkgDb, ClearPackageDBs]
            }

    dflags2 <- liftIO $ initUnits dflags1
    modifySession $ \env -> env{ hsc_dflags = dflags2 }
    invalidateModSummaryCache
  where
    -- Environment variable that overrides the package database location.
    primPkgDbVar = "GHC_LIB_PACKAGE_DB"

    -- Fallback location, kept for backward compatibility.
    defaultPrimPkgDb = "/home/mi/.ghcup/ghc/9.0.1/lib/ghc-9.0.1/package.conf.d"

-- | A module called @Main@ that lives in the hole unit. 'setup' installs it
-- as 'mainModIs'. Presumably this keeps any real module from being treated
-- as the program's Main module (NOTE(review): confirm intent).
noMainModule :: Module
noMainModule =
    let mainName = mkModuleName "Main"
    in mkModule HoleUnit mainName

-- | Replace the 'UnitState' held in a summary's preprocessed 'DynFlags'
-- ('ms_hspp_opts'). All other fields of the summary are left unchanged.
withUnitState :: UnitState -> ModSummary -> ModSummary
withUnitState us ms =
    let hsppOpts = ms_hspp_opts ms
    in ms { ms_hspp_opts = hsppOpts { unitState = us } }

-- | Turn off the @ImplicitPrelude@ extension in the session's 'DynFlags'.
--
-- This uses 'xopt_unset' instead of deleting the bit from 'extensionFlags'
-- by hand. 'xopt_unset' also records @Off ImplicitPrelude@ in 'extensions'.
-- GHC's 'xopt_set' / 'xopt_unset' recompute 'extensionFlags' from
-- 'extensions', for example when a source file's LANGUAGE pragmas are applied
-- on top of these flags. With the manual delete, that recomputation would
-- quietly re-enable the implicit Prelude import.
noPrelude :: HscEnv -> HscEnv
noPrelude env = env
    { hsc_dflags = xopt_unset (hsc_dflags env) Ext.ImplicitPrelude
    }

-- | Typecheck one module summary and register the result in the session.
--
-- Prints the module name, then refreshes the summary's unit state from the
-- current session. It parses, typechecks and builds an interface via
-- 'prepareSource', and finally registers the interface and details with
-- both 'registerModule' and 'registerToEps'.
loadModule :: (GhcMonad m) => ModSummary -> m ()
loadModule summary = do
    let name = moduleNameString (moduleName (ms_mod summary))
    liftIO (putStrLn (unwords ["loadModule", name]))

    -- Loading dependencies has changed the unitState, so we
    -- need to refresh the one stored in the ModSummary
    currentUnits <- unitState . hsc_dflags <$> getSession
    let refreshed = withUnitState currentUnits summary

    (_, iface, details) <- prepareSource refreshed
    registerModule iface details

    -- It seems to me that this is what should register the
    -- instances. Alas, calling or not calling registerToEps here
    -- doesn't make a difference...
    registerToEps iface details

-- | Make a module's source file the sole target and return its 'ModSummary'.
--
-- Steps:
--
-- * Run 'downsweep' with the implicit Prelude turned off ('noPrelude').
-- * Skip every module name already present in the unit state's
--   'moduleNameProvidersMap'.
-- * Rethrow any downsweep errors via 'reportErrors'.
-- * Look up the summary for the requested module in 'mainUnit'.
--
-- If no such summary is found, an 'IOError' naming the module is raised.
-- Previously this case failed with an opaque irrefutable-pattern error.
prepareModule :: (GhcMonad m) => String -> m ModSummary
prepareModule modName = do
    let mod = mkModuleName modName
        target = resolve mod
    setTargets [target]

    -- Anything already in the provided map should be left as-is
    providers <- moduleNameProvidersMap . unitState . hsc_dflags <$> getSession
    let exclude = M.keys providers

    (errs, mss) <- do
        env <- getSession
        liftIO $ partitionEithers <$> downsweep (noPrelude env) [] exclude False
    reportErrors errs

    let menv = mkModuleEnv [(ms_mod ms, ms) | ms <- mss]

    case lookupModuleEnv menv (mkModule mainUnit mod) of
        Just ms -> return ms
        Nothing -> liftIO $ ioError $ userError $
            "prepareModule: no module summary found for " ++ modName

-- | Merge a list of error bags and throw them as a single 'SourceError'
-- via 'throwErrors'. If the merged bag is empty, do nothing.
reportErrors :: (GhcMonad m) => [ErrorMessages] -> m ()
reportErrors errs =
    let merged = unionManyBags errs
    in unless (isEmptyBag merged) (throwErrors merged)

-- | Parse and typecheck a module summary, then build its interface with
-- 'mkModIface'.
--
-- Returns the typechecked module together with the interface and the
-- 'ModDetails' taken from the typechecker's internals.
prepareSource
    :: (GhcMonad m)
    => ModSummary
    -> m (TypecheckedModule, ModIface, ModDetails)
prepareSource ms = do
    typechecked <- typecheckModule =<< parseModule ms
    hscEnv <- getSession
    let (tcgEnv, modDetails) = tm_internals_ typechecked
    iface <- liftIO (mkModIface hscEnv tcgEnv modDetails)
    pure (typechecked, iface, modDetails)

-- | Build the target for a module's source file.
--
-- The file is @input/\<Module/Path\>.src@ and is fed into the pipeline at
-- the 'Cpp' phase of an 'HsSrcFile'. Object code is not allowed, and no
-- in-memory contents are supplied (@targetContents = Nothing@).
resolve :: ModuleName -> Target
resolve mod = Target
    { targetId = TargetFile srcFile (Just (Cpp HsSrcFile))
    , targetAllowObjCode = False
    , targetContents = Nothing
    }
  where
    srcFile = "input" </> moduleNameSlashes mod <.> "src"

-- | Move the source timestamp ('ms_hs_date') of every summary in the
-- session's module graph back by one second.
--
-- NOTE(review): presumably this makes the cached summaries look out of date
-- so that GHC re-summarises them. Confirm against 'GHC.Driver.Make'.
invalidateModSummaryCache :: (GhcMonad m) => m ()
invalidateModSummaryCache = modifySession backdateGraph
  where
    backdateGraph env = env { hsc_mod_graph = mapMG backdate (hsc_mod_graph env) }
    backdate summary =
        summary { ms_hs_date = addUTCTime (-1) (ms_hs_date summary) }

-- | Build a 'ModIface' for a typechecked module without letting GHC compute
-- fingerprints.
--
-- Prints the partial interface's @dep_mods@, @dep_pkgs@ and @dep_orphs@ to
-- stdout as a debugging aid. It then completes the partial interface:
--
-- * every declaration gets the dummy fingerprint 'fingerprint0';
-- * the backend extensions come from an empty full interface for the same
--   module, with the fixity cache rebuilt from the partial interface's
--   fixities.
mkModIface :: HscEnv -> TcGblEnv -> ModDetails -> IO ModIface
mkModIface hsc_env tcg mod_details = do
    -- Of course, we should be using mkIfaceTc here directly. However,
    -- that leads to GHC trying to compute fingerprints on its own,
    -- which involves loading the .hi files that we don't generate:
    --
    -- Bad interface file: input/Instance.hi
    -- input/Instance.hi: openBinaryFile: does not exist (No such file or directory)
    --
    -- So instead, we patch GHC to make mkPartialIfaceTc public and
    -- use it directly, and use dummy fingerprints for now.

    partialIface <- mkPartialIfaceTc hsc_env Sf_Ignore mod_details tcg
    let deps = mi_deps partialIface
        depSummary = (dep_mods deps, dep_pkgs deps, dep_orphs deps)
    putStrLn (showPpr (hsc_dflags hsc_env) depSummary)
    pure (completeIface partialIface)
  where
    -- Supplies the backend-extension fields other than the fixity cache.
    blankIface = emptyFullModIface (tcg_mod tcg)

    completeIface :: PartialModIface -> ModIface
    completeIface iface = iface
        { mi_decls = [ (fingerprint0, decl) | decl <- mi_decls iface ]
        , mi_final_exts = (mi_final_exts blankIface)
            { mi_fix_fn = mkIfaceFixCache (mi_fixities iface)
            }
        }

-- | Apply a function to the 'UnitState' stored in the session's 'DynFlags'.
-- The rest of the environment is left untouched.
modifyUnitState :: (UnitState -> UnitState) -> HscEnv -> HscEnv
modifyUnitState f env =
    let dflags = hsc_dflags env
        dflags' = dflags { unitState = f (unitState dflags) }
    in env { hsc_dflags = dflags' }

-- | Make a freshly typechecked module visible to the session.
--
-- It always replaces the module name's entry in the unit state's
-- 'moduleNameProvidersMap' with this module alone. If the module belongs to
-- the home unit, it also adds a 'HomeModInfo' (no linkable) to the home
-- package table and clears 'hsc_type_env_var'. Modules from other units get
-- only the provider-map update.
registerModule :: (GhcMonad m) => ModIface -> ModDetails -> m ()
registerModule iface details = modifySession (\env -> extendHpt (addProvider env))
  where
    thisMod = mi_module iface

    -- NOTE(review): presumably "exposed by its own unit, no re-exports,
    -- exposed by a package flag". Confirm against GHC.Unit.State.ModOrigin.
    origin = ModOrigin (Just True) [] [] True

    homeInfo = HomeModInfo iface details Nothing

    addProvider = modifyUnitState $ \us ->
        let providers = moduleNameProvidersMap us
            entry = M.singleton thisMod origin
        in us { moduleNameProvidersMap = M.insert (moduleName thisMod) entry providers }

    extendHpt env
        | toUnitId (moduleUnit thisMod) /= homeUnitId (hsc_dflags env) = env
        | otherwise = env
            { hsc_HPT = addToHpt (hsc_HPT env) (moduleName thisMod) homeInfo
            , hsc_type_env_var = Nothing
            }

-- | Merge a module's interface and details into the session's
-- 'ExternalPackageState'.
--
-- Adds the interface to the package interface table. It extends the type
-- environment, rule base, complete-match map, class and family instance
-- environments, annotation environment and per-module family instance
-- environment with the module's contents. It also bumps the EPS statistics
-- by the number of type-env entries, instances and rules added.
--
-- The update runs through 'updateEps_' inside a minimal 'IOEnv', built from
-- the current 'HscEnv' with unit global and local environments.
registerToEps :: (GhcMonad m) => ModIface -> ModDetails -> m ()
registerToEps iface details = do
    hscEnv <- getSession
    let tcEnv = Env hscEnv '\0' () ()
    liftIO $ runIOEnv tcEnv (updateEps_ extendEps)
  where
    thisMod = mi_module iface
    types = md_types details
    insts = md_insts details
    rules = md_rules details
    famInsts = md_fam_insts details

    extendEps :: ExternalPackageState -> ExternalPackageState
    extendEps eps = eps
        { eps_PIT = extendModuleEnv (eps_PIT eps) thisMod iface
        , eps_PTE = plusTypeEnv (eps_PTE eps) types
        , eps_rule_base = extendRuleBaseList (eps_rule_base eps) rules
        , eps_complete_matches =
            extendCompleteMatchMap (eps_complete_matches eps) (md_complete_sigs details)
        , eps_inst_env = extendInstEnvList (eps_inst_env eps) insts
        , eps_fam_inst_env = extendFamInstEnvList (eps_fam_inst_env eps) famInsts
        , eps_ann_env = extendAnnEnvList (eps_ann_env eps) (md_anns details)
        , eps_mod_fam_inst_env =
            extendModuleEnv (eps_mod_fam_inst_env eps) thisMod
                (extendFamInstEnvList emptyFamInstEnv famInsts)
        , eps_stats = addEpsInStats (eps_stats eps)
            (length (typeEnvElts types)) (length insts) (length rules)
        }
