import SimpleITK as sitk
import sys
import os
import numpy as np

def convert_to_grayscale(input_filename, output_filename=None,
                         weights=(0.2989, 0.5870, 0.1140)):
    """Convert a colour image (e.g. an RGB DICOM) to a single-channel image.

    The image is read with SimpleITK. The first three components of every
    pixel are combined as ``w0*C0 + w1*C1 + w2*C2``. The result is written
    with the input's origin, spacing and direction copied onto it.

    The output is floating point (float64 for integer inputs). It is not
    cast back to the input pixel type. Components beyond the third (e.g. an
    alpha channel) are ignored.

    Args:
        input_filename: Path of the image to read with ``sitk.ReadImage``.
        output_filename: Destination path. Defaults to
            ``<input path without extension>_grayscale.nrrd``.
        weights: Three coefficients applied to components 0, 1 and 2. The
            default is the Rec. 601 luma weighting (as in MATLAB
            ``rgb2gray``), which assumes the components are R, G, B.

    Returns:
        The path the grayscale image was written to.

    Raises:
        ValueError: If ``weights`` does not contain exactly three values.
        SystemExit: With status 1 if the image has fewer than three
            components per pixel.
    """
    # Normalise to plain Python floats so NumPy's dtype promotion is the same
    # as with the original literal coefficients.
    weights = tuple(float(w) for w in weights)
    if len(weights) != 3:
        raise ValueError(
            f"weights must contain exactly 3 values, got {len(weights)}")
    w0, w1, w2 = weights

    image = sitk.ReadImage(input_filename)

    if image.GetNumberOfComponentsPerPixel() < 3:
        # sys.exit(str) reports the message on stderr with exit status 1.
        # SystemExit is kept (rather than ValueError) so that CLI behaviour
        # and any caller catching it are unchanged.
        sys.exit("Error: The image does not have multiple color channels.")

    # For multi-component images the components are the last array axis.
    image_array = sitk.GetArrayFromImage(image)

    # Same evaluation order as before: (w0*C0 + w1*C1) + w2*C2.
    grayscale_array = (w0 * image_array[..., 0] +
                       w1 * image_array[..., 1] +
                       w2 * image_array[..., 2])

    grayscale_image = sitk.GetImageFromArray(grayscale_array)
    # Keep origin/spacing/direction so the result stays in the same
    # physical space as the input.
    grayscale_image.CopyInformation(image)

    if output_filename is None:
        stem = os.path.splitext(input_filename)[0]
        output_filename = f"{stem}_grayscale.nrrd"

    sitk.WriteImage(grayscale_image, output_filename)
    print(f"Grayscale image saved as: {output_filename}")
    return output_filename

if __name__ == "__main__":
    # Command-line entry point: expects exactly one argument, the input image.
    if len(sys.argv) != 2:
        # sys.exit(str) writes the usage line to stderr and exits with status
        # 1, instead of mixing it into stdout. Use the real program name
        # rather than a placeholder.
        prog = os.path.basename(sys.argv[0]) or "script.py"
        sys.exit(f"Usage: python {prog} <input_dicom_file>")

    input_file = sys.argv[1]
    convert_to_grayscale(input_file)
