// ML Stable Diffusion — StableDiffusionPipeline and related extensions.
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
 
import Accelerate
import CoreGraphics
import CoreML
import Foundation
import NaturalLanguage
 
/// Schedulers compatible with StableDiffusionPipeline
/// Schedulers compatible with StableDiffusionPipeline
public enum StableDiffusionScheduler {
    /// Scheduler that uses a pseudo-linear multi-step (PLMS) method (backed by `PNDMScheduler`)
    case pndmScheduler
    /// Scheduler that uses a second order DPM-Solver++ algorithm (backed by `DPMSolverMultistepScheduler`)
    case dpmSolverMultistepScheduler
}
 
/// RNG compatible with StableDiffusionPipeline
/// RNG compatible with StableDiffusionPipeline
///
/// Choosing the RNG that matches the reference implementation allows a given seed
/// to reproduce that implementation's latent noise.
public enum StableDiffusionRNG {
    /// RNG that matches numpy implementation (backed by `NumPyRandomSource`)
    case numpyRNG
    /// RNG that matches PyTorch CPU implementation (backed by `TorchRandomSource`)
    case torchRNG
}
 
/// A pipeline used to generate image samples from text input using stable diffusion
///
/// This implementation matches:
/// [Hugging Face Diffusers Pipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py)
@available(iOS 16.2, macOS 13.1, *)
public struct StableDiffusionPipeline: ResourceManaging {
    
    /// Errors thrown directly by the pipeline (errors from the underlying models propagate unchanged)
    public enum Error: String, Swift.Error {
        /// A starting image was supplied in the configuration but this pipeline has no `encoder`
        case startingImageProvidedWithoutEncoder
        /// A feature was requested on an OS version that does not support it
        case unsupportedOSVersion
    }
    
    /// Model to generate embeddings for tokenized input text
    var textEncoder: TextEncoderModel

    /// Model used to predict noise residuals given an input, diffusion time step, and conditional embedding
    var unet: Unet

    /// Model used to generate final image from latent diffusion process
    var decoder: Decoder
    
    /// Model used to encode a starting image into latent space for image2image, and soon, in-painting
    var encoder: Encoder?

    /// Optional model for checking safety of generated image
    var safetyChecker: SafetyChecker? = nil
    
    /// Optional model used before Unet to control generated images by additional inputs
    var controlNet: ControlNet? = nil

    /// Reports whether this pipeline can perform safety checks
    public var canSafetyCheck: Bool {
        safetyChecker != nil
    }

    /// Option to reduce memory during image generation
    ///
    /// If true, the pipeline will lazily load TextEncoder, Unet, Decoder, and SafetyChecker
    /// when needed and aggressively unload their resources after
    ///
    /// This will increase latency in favor of reducing memory
    var reduceMemory: Bool = false

    /// Option to use system multilingual NLContextualEmbedding as encoder
    var useMultilingualTextEncoder: Bool = false

    /// Optional natural language script to use for the text encoder.
    var script: Script? = nil

    /// Creates a pipeline using the specified models and tokenizer
    ///
    /// - Parameters:
    ///   - textEncoder: Model for encoding tokenized text
    ///   - unet: Model for noise prediction on latent samples
    ///   - decoder: Model for decoding latent sample to image
    ///   - encoder: Optional model for encoding a starting image into latent space (image2image)
    ///   - controlNet: Optional model to control generated images by additional inputs
    ///   - safetyChecker: Optional model for checking safety of generated images
    ///   - reduceMemory: Option to enable reduced memory mode
    /// - Returns: Pipeline ready for image generation
    public init(
        textEncoder: TextEncoderModel,
        unet: Unet,
        decoder: Decoder,
        encoder: Encoder?,
        controlNet: ControlNet? = nil,
        safetyChecker: SafetyChecker? = nil,
        reduceMemory: Bool = false
    ) {
        self.textEncoder = textEncoder
        self.unet = unet
        self.decoder = decoder
        self.encoder = encoder
        self.controlNet = controlNet
        self.safetyChecker = safetyChecker
        self.reduceMemory = reduceMemory
    }

    /// Creates a pipeline using the specified models and tokenizer
    ///
    /// - Parameters:
    ///   - textEncoder: Model for encoding tokenized text
    ///   - unet: Model for noise prediction on latent samples
    ///   - decoder: Model for decoding latent sample to image
    ///   - encoder: Optional model for encoding a starting image into latent space (image2image)
    ///   - controlNet: Optional model to control generated images by additional inputs
    ///   - safetyChecker: Optional model for checking safety of generated images
    ///   - reduceMemory: Option to enable reduced memory mode
    ///   - useMultilingualTextEncoder: Option to use system multilingual NLContextualEmbedding as encoder
    ///   - script: Optional natural language script to use for the text encoder.
    /// - Returns: Pipeline ready for image generation
    @available(iOS 17.0, macOS 14.0, *)
    public init(
        textEncoder: TextEncoderModel,
        unet: Unet,
        decoder: Decoder,
        encoder: Encoder?,
        controlNet: ControlNet? = nil,
        safetyChecker: SafetyChecker? = nil,
        reduceMemory: Bool = false,
        useMultilingualTextEncoder: Bool = false,
        script: Script? = nil
    ) {
        self.textEncoder = textEncoder
        self.unet = unet
        self.decoder = decoder
        self.encoder = encoder
        self.controlNet = controlNet
        self.safetyChecker = safetyChecker
        self.reduceMemory = reduceMemory
        self.useMultilingualTextEncoder = useMultilingualTextEncoder
        self.script = script
    }

    /// Load required resources for this pipeline
    ///
    /// If `reduceMemory` is true this will call `prewarmResources` instead
    /// and let the pipeline lazily load resources as needed
    public func loadResources() throws {
        if reduceMemory {
            try prewarmResources()
        } else {
            try unet.loadResources()
            try textEncoder.loadResources()
            try decoder.loadResources()
            try encoder?.loadResources()
            try controlNet?.loadResources()
            try safetyChecker?.loadResources()
        }
    }

    /// Unload the underlying resources to free up memory
    public func unloadResources() {
        textEncoder.unloadResources()
        unet.unloadResources()
        decoder.unloadResources()
        encoder?.unloadResources()
        controlNet?.unloadResources()
        safetyChecker?.unloadResources()
    }

    // Prewarm resources one model at a time to limit peak memory use
    public func prewarmResources() throws {
        try textEncoder.prewarmResources()
        try unet.prewarmResources()
        try decoder.prewarmResources()
        try encoder?.prewarmResources()
        try controlNet?.prewarmResources()
        try safetyChecker?.prewarmResources()
    }

    /// Image generation using stable diffusion
    /// - Parameters:
    ///   - configuration: Image generation configuration
    ///   - progressHandler: Callback to perform after each step, stops on receiving false response
    /// - Returns: An array of `imageCount` optional images.
    ///            The images will be nil if safety checks were performed and found the result to be un-safe
    public func generateImages(
        configuration config: Configuration,
        progressHandler: (Progress) -> Bool = { _ in true }
    ) throws -> [CGImage?] {

        // Encode the input prompt and negative prompt
        let promptEmbedding = try textEncoder.encode(config.prompt)
        let negativePromptEmbedding = try textEncoder.encode(config.negativePrompt)

        if reduceMemory {
            textEncoder.unloadResources()
        }

        // Convert to Unet hidden state representation
        // Concatenate the prompt and negative prompt embeddings along the batch axis.
        // The negative (unconditioned) embedding comes first; performGuidance relies on this order.
        let concatEmbedding = MLShapedArray<Float32>(
            concatenating: [negativePromptEmbedding, promptEmbedding],
            alongAxis: 0
        )

        // The multilingual encoder's output is used as-is; otherwise transpose
        // the embedding into the layout the Unet expects
        let hiddenStates = useMultilingualTextEncoder ? concatEmbedding : toHiddenStates(concatEmbedding)

        // Set up one scheduler per requested image
        let scheduler: [Scheduler] = (0..<config.imageCount).map { _ in
            switch config.schedulerType {
            case .pndmScheduler: return PNDMScheduler(stepCount: config.stepCount)
            case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: config.stepCount)
            }
        }

        // Generate random latent samples from specified seed
        var latents: [MLShapedArray<Float32>] = try generateLatentSamples(configuration: config, scheduler: scheduler[0])
        if reduceMemory {
            encoder?.unloadResources()
        }
        // Only image-to-image uses a strength-limited timestep schedule
        let timestepStrength: Float? = config.mode == .imageToImage ? config.strength : nil
        
        // Convert cgImage for ControlNet into MLShapedArray,
        // duplicated along the batch axis to match the doubled (guidance) latent batch
        let controlNetConds = try config.controlNetInputs.map { cgImage in
            let shapedArray = try cgImage.plannerRGBShapedArray(minValue: 0.0, maxValue: 1.0)
            return MLShapedArray(
                concatenating: [shapedArray, shapedArray],
                alongAxis: 0
            )
        }

        // De-noising loop
        let timeSteps: [Int] = scheduler[0].calculateTimesteps(strength: timestepStrength)
        for (step,t) in timeSteps.enumerated() {

            // Expand the latents for classifier-free guidance
            // and input to the Unet noise prediction model
            let latentUnetInput = latents.map {
                MLShapedArray<Float32>(concatenating: [$0, $0], alongAxis: 0)
            }
            
            // Before Unet, execute controlNet and add the output into Unet inputs
            let additionalResiduals = try controlNet?.execute(
                latents: latentUnetInput,
                timeStep: t,
                hiddenStates: hiddenStates,
                images: controlNetConds
            )
            
            // Predict noise residuals from latent samples
            // and current time step conditioned on hidden states
            var noise = try unet.predictNoise(
                latents: latentUnetInput,
                timeStep: t,
                hiddenStates: hiddenStates,
                additionalResiduals: additionalResiduals
            )

            noise = performGuidance(noise, config.guidanceScale)

            // Have the scheduler compute the previous (t-1) latent
            // sample given the predicted noise and current sample
            for i in 0..<config.imageCount {
                latents[i] = scheduler[i].step(
                    output: noise[i],
                    timeStep: t,
                    sample: latents[i]
                )
            }

            // Report progress
            let progress = Progress(
                pipeline: self,
                prompt: config.prompt,
                step: step,
                stepCount: timeSteps.count,
                currentLatentSamples: latents,
                configuration: config
            )
            if !progressHandler(progress) {
                // Stop if requested by handler.
                // NOTE(review): cancellation returns an empty array, not partial results.
                return []
            }
        }

        if reduceMemory {
            controlNet?.unloadResources()
            unet.unloadResources()
        }

        // Decode the latent samples to images
        return try decodeToImages(latents, configuration: config)
    }

    /// Create the random source matching the requested RNG implementation,
    /// deterministically seeded with `seed`
    private func randomSource(from rng: StableDiffusionRNG, seed: UInt32) -> RandomSource {
        switch rng {
        case .numpyRNG:
            return NumPyRandomSource(seed: seed)
        case .torchRNG:
            return TorchRandomSource(seed: seed)
        }
    }

    /// Generate the initial latent samples, one per requested image
    ///
    /// For image-to-image, the starting image is encoded to a latent and combined
    /// with noise according to the configured strength; otherwise pure Gaussian
    /// noise (scaled by the scheduler's initial sigma) is returned.
    /// - Throws: `Error.startingImageProvidedWithoutEncoder` when a starting image
    ///   is configured but this pipeline has no encoder model
    func generateLatentSamples(configuration config: Configuration, scheduler: Scheduler) throws -> [MLShapedArray<Float32>] {
        var sampleShape = unet.latentSampleShape
        sampleShape[0] = 1
        
        let stdev = scheduler.initNoiseSigma
        var random = randomSource(from: config.rngType, seed: config.seed)
        let samples = (0..<config.imageCount).map { _ in
            MLShapedArray<Float32>(
                converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev)))
        }
        if let image = config.startingImage, config.mode == .imageToImage {
            guard let encoder else {
                throw Error.startingImageProvidedWithoutEncoder
            }
            let latent = try encoder.encode(image, scaleFactor: config.encoderScaleFactor, random: &random)
            return scheduler.addNoise(originalSample: latent, noise: samples, strength: config.strength)
        }
        return samples
    }

    /// Rearrange a text embedding into the hidden-state layout consumed by the Unet
    func toHiddenStates(_ embedding: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
        // Unoptimized manual transpose [0, 2, None, 1]
        // e.g. From [2, 77, 768] to [2, 768, 1, 77]
        let fromShape = embedding.shape
        let stateShape = [fromShape[0],fromShape[2], 1, fromShape[1]]
        var states = MLShapedArray<Float32>(repeating: 0.0, shape: stateShape)
        for i0 in 0..<fromShape[0] {
            for i1 in 0..<fromShape[1] {
                for i2 in 0..<fromShape[2] {
                    states[scalarAt:i0,i2,0,i1] = embedding[scalarAt:i0, i1, i2]
                }
            }
        }
        return states
    }

    /// Apply classifier-free guidance to each image's noise prediction
    func performGuidance(_ noise: [MLShapedArray<Float32>], _ guidanceScale: Float) -> [MLShapedArray<Float32>] {
        noise.map { performGuidance($0, guidanceScale) }
    }

    /// Collapse a batched (unconditioned, text-conditioned) noise prediction into a
    /// single guided prediction: unconditioned + guidanceScale * (text - unconditioned)
    ///
    /// The first half of the batch is the unconditioned (negative-prompt) prediction,
    /// matching the concatenation order used in `generateImages`.
    func performGuidance(_ noise: MLShapedArray<Float32>, _ guidanceScale: Float) -> MLShapedArray<Float32> {
        var shape = noise.shape
        shape[0] = 1
        return MLShapedArray<Float>(unsafeUninitializedShape: shape) { result, _ in
            noise.withUnsafeShapedBufferPointer { scalars, _, strides in
                for i in 0 ..< result.count {
                    // unconditioned + guidance*(text - unconditioned);
                    // strides[0] is the offset of the second (text-conditioned) batch element
                    result.initializeElement(
                        at: i,
                        to: scalars[i] + guidanceScale * (scalars[strides[0] + i] - scalars[i])
                    )
                }
            }
        }
    }

    /// Decode latent samples into images, applying the safety checker when enabled
    ///
    /// - Returns: One image per latent; an entry is nil when the safety checker
    ///   was run and judged the image unsafe
    func decodeToImages(_ latents: [MLShapedArray<Float32>], configuration config: Configuration) throws -> [CGImage?] {
        let images = try decoder.decode(latents, scaleFactor: config.decoderScaleFactor)
        if reduceMemory {
            decoder.unloadResources()
        }

        // If safety is disabled return what was decoded
        if config.disableSafety {
            return images
        }

        // If there is no safety checker return what was decoded
        guard let safetyChecker = safetyChecker else {
            return images
        }

        // Otherwise change images which are not safe to nil
        let safeImages = try images.map { image in
            try safetyChecker.isSafe(image) ? image : nil
        }

        if reduceMemory {
            safetyChecker.unloadResources()
        }

        return safeImages
    }

}
 
@available(iOS 16.2, macOS 13.1, *)
extension StableDiffusionPipeline {
    /// Sampling progress details reported to `generateImages`'s progress handler after each step
    public struct Progress {
        public let pipeline: StableDiffusionPipeline
        public let prompt: String
        /// Zero-based index of the step that just completed
        public let step: Int
        /// Total number of de-noising steps that will be performed
        public let stepCount: Int
        public let currentLatentSamples: [MLShapedArray<Float32>]
        public let configuration: Configuration
        /// True when the pipeline has a safety checker and the configuration has not disabled it
        public var isSafetyEnabled: Bool {
            pipeline.canSafetyCheck && !configuration.disableSafety
        }
        /// Decodes the current latent samples into intermediate preview images.
        ///
        /// NOTE(review): the `try!` crashes the process if decoding throws mid-generation;
        /// consider propagating the error or returning an empty array instead.
        public var currentImages: [CGImage?] {
            try! pipeline.decodeToImages(currentLatentSamples, configuration: configuration)
        }
    }
}
 
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
 
import Foundation
import CoreGraphics
 
@available(iOS 16.2, macOS 13.1, *)
extension StableDiffusionPipeline {
    
    /// Type of processing that will be performed to generate an image
    public enum Mode {
        case textToImage
        case imageToImage
        // case inPainting
    }
    
    /// Image generation configuration
    public struct Configuration: Hashable {
        
        /// Text prompt to guide sampling
        public var prompt: String
        /// Negative text prompt to guide sampling
        public var negativePrompt: String = ""
        /// Starting image for image2image or in-painting
        public var startingImage: CGImage? = nil
        //public var maskImage: CGImage? = nil
        /// How strongly the diffusion process overrides the starting image (1.0 ignores it)
        public var strength: Float = 1.0
        /// Number of images to generate
        public var imageCount: Int = 1
        /// Number of inference steps to perform
        public var stepCount: Int = 50
        /// Random seed which to start generation
        public var seed: UInt32 = 0
        /// Controls the influence of the text prompt on sampling process (0=random images)
        public var guidanceScale: Float = 7.5
        /// List of Images for available ControlNet Models
        public var controlNetInputs: [CGImage] = []
        /// Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
        public var disableSafety: Bool = false
        /// The type of Scheduler to use.
        public var schedulerType: StableDiffusionScheduler = .pndmScheduler
        /// The type of RNG to use
        public var rngType: StableDiffusionRNG = .numpyRNG
        /// Scale factor to use on the latent after encoding
        public var encoderScaleFactor: Float32 = 0.18215
        /// Scale factor to use on the latent before decoding
        public var decoderScaleFactor: Float32 = 0.18215
        
        /// Given the configuration, what mode will be used for generation
        ///
        /// Image-to-image requires both a starting image and a strength below 1.0;
        /// any other combination results in plain text-to-image.
        public var mode: Mode {
            if startingImage != nil, strength < 1.0 {
                return .imageToImage
            }
            return .textToImage
        }
        
        /// Create a configuration for the given prompt; all other settings keep their defaults
        public init(prompt: String) {
            self.prompt = prompt
        }
        
    }

}
 
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
 
import Foundation
import CoreML
import NaturalLanguage
 
@available(iOS 16.2, macOS 13.1, *)
public extension StableDiffusionPipeline {

    /// Well-known locations of every model and tokenizer resource inside a resource directory
    struct ResourceURLs {

        public let textEncoderURL: URL
        public let unetURL: URL
        public let unetChunk1URL: URL
        public let unetChunk2URL: URL
        public let decoderURL: URL
        public let encoderURL: URL
        public let safetyCheckerURL: URL
        public let vocabURL: URL
        public let mergesURL: URL
        public let controlNetDirURL: URL
        public let controlledUnetURL: URL
        public let controlledUnetChunk1URL: URL
        public let controlledUnetChunk2URL: URL
        public let multilingualTextEncoderProjectionURL: URL

        public init(resourcesAt baseURL: URL) {
            // Every resource lives directly under the base directory with a fixed name
            func locate(_ name: String) -> URL { baseURL.appending(path: name) }
            textEncoderURL = locate("TextEncoder.mlmodelc")
            unetURL = locate("Unet.mlmodelc")
            unetChunk1URL = locate("UnetChunk1.mlmodelc")
            unetChunk2URL = locate("UnetChunk2.mlmodelc")
            decoderURL = locate("VAEDecoder.mlmodelc")
            encoderURL = locate("VAEEncoder.mlmodelc")
            safetyCheckerURL = locate("SafetyChecker.mlmodelc")
            vocabURL = locate("vocab.json")
            mergesURL = locate("merges.txt")
            controlNetDirURL = locate("controlnet")
            controlledUnetURL = locate("ControlledUnet.mlmodelc")
            controlledUnetChunk1URL = locate("ControlledUnetChunk1.mlmodelc")
            controlledUnetChunk2URL = locate("ControlledUnetChunk2.mlmodelc")
            multilingualTextEncoderProjectionURL = locate("MultilingualTextEncoderProjection.mlmodelc")
        }
    }

    /// Create stable diffusion pipeline using model resources at a
    /// specified URL
    ///
    /// - Parameters:
    ///   - baseURL: URL pointing to directory holding all model and tokenization resources
    ///   - controlNetModelNames: Specify ControlNet models to use in generation
    ///   - configuration: The configuration to load model resources with
    ///   - disableSafety: Load time disable of safety to save memory
    ///   - reduceMemory: Setup pipeline in reduced memory mode
    ///   - useMultilingualTextEncoder: Option to use system multilingual NLContextualEmbedding as encoder
    ///   - script: Optional natural language script to use for the text encoder.
    /// - Returns:
    ///  Pipeline ready for image generation if all necessary resources loaded
    init(
        resourcesAt baseURL: URL,
        controlNet controlNetModelNames: [String],
        configuration config: MLModelConfiguration = .init(),
        disableSafety: Bool = false,
        reduceMemory: Bool = false,
        useMultilingualTextEncoder: Bool = false,
        script: Script? = nil
    ) throws {

        // Expected location of each resource
        let urls = ResourceURLs(resourcesAt: baseURL)
        let fileExists: (String) -> Bool = { FileManager.default.fileExists(atPath: $0) }

        // Text encoder: the multilingual variant is only available on newer OS releases
        let textEncoder: TextEncoderModel
        if useMultilingualTextEncoder {
            guard #available(macOS 14.0, iOS 17.0, *) else { throw Error.unsupportedOSVersion }
            textEncoder = MultilingualTextEncoder(
                modelAt: urls.multilingualTextEncoderProjectionURL,
                configuration: config,
                script: script ?? .latin
            )
        } else {
            textEncoder = TextEncoder(
                tokenizer: try BPETokenizer(mergesAt: urls.mergesURL, vocabularyAt: urls.vocabURL),
                modelAt: urls.textEncoderURL,
                configuration: config
            )
        }

        // ControlNet is only constructed when at least one model name was supplied
        let controlNetURLs = controlNetModelNames.map { name in
            urls.controlNetDirURL.appending(path: name + ".mlmodelc")
        }
        let controlNet: ControlNet? = controlNetURLs.isEmpty
            ? nil
            : ControlNet(modelAt: controlNetURLs, configuration: config)

        // When a ControlNet is present, select the Unet variant that accepts its additional inputs
        let (unetURL, unetChunk1URL, unetChunk2URL) = (controlNet == nil)
            ? (urls.unetURL, urls.unetChunk1URL, urls.unetChunk2URL)
            : (urls.controlledUnetURL, urls.controlledUnetChunk1URL, urls.controlledUnetChunk2URL)

        // Prefer the chunked Unet when both chunks exist on disk
        let unet: Unet
        if fileExists(unetChunk1URL.path) && fileExists(unetChunk2URL.path) {
            unet = Unet(chunksAt: [unetChunk1URL, unetChunk2URL],
                        configuration: config)
        } else {
            unet = Unet(modelAt: unetURL, configuration: config)
        }

        // Image Decoder
        let decoder = Decoder(modelAt: urls.decoderURL, configuration: config)

        // Safety checking is opt-out, and is also skipped when its model file is absent
        let safetyChecker: SafetyChecker? = (!disableSafety && fileExists(urls.safetyCheckerURL.path))
            ? SafetyChecker(modelAt: urls.safetyCheckerURL, configuration: config)
            : nil

        // Optional image encoder (needed for image2image)
        let encoder: Encoder? = fileExists(urls.encoderURL.path)
            ? Encoder(modelAt: urls.encoderURL, configuration: config)
            : nil

        // Construct pipeline; the multilingual/script options require the newer initializer
        if #available(macOS 14.0, iOS 17.0, *) {
            self.init(
                textEncoder: textEncoder,
                unet: unet,
                decoder: decoder,
                encoder: encoder,
                controlNet: controlNet,
                safetyChecker: safetyChecker,
                reduceMemory: reduceMemory,
                useMultilingualTextEncoder: useMultilingualTextEncoder,
                script: script
            )
        } else {
            self.init(
                textEncoder: textEncoder,
                unet: unet,
                decoder: decoder,
                encoder: encoder,
                controlNet: controlNet,
                safetyChecker: safetyChecker,
                reduceMemory: reduceMemory
            )
        }
    }
}