// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.

import Accelerate
import CoreGraphics
import CoreML
import Foundation
import NaturalLanguage

/// Schedulers compatible with StableDiffusionPipeline
public enum StableDiffusionScheduler {
    /// Scheduler that uses a pseudo-linear multi-step (PLMS) method
    case pndmScheduler
    /// Scheduler that uses a second order DPM-Solver++ algorithm
    case dpmSolverMultistepScheduler
}

/// RNG compatible with StableDiffusionPipeline
public enum StableDiffusionRNG {
    /// RNG that matches numpy implementation
    case numpyRNG
    /// RNG that matches PyTorch CPU implementation.
    case torchRNG
}

/// A pipeline used to generate image samples from text input using stable diffusion
///
/// This implementation matches:
/// [Hugging Face Diffusers Pipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py)
@available(iOS 16.2, macOS 13.1, *)
public struct StableDiffusionPipeline: ResourceManaging {

    /// Errors the pipeline can throw during setup or generation
    public enum Error: String, Swift.Error {
        /// image2image was requested but no VAE encoder model was provided
        case startingImageProvidedWithoutEncoder
        /// A requested feature requires a newer OS version
        case unsupportedOSVersion
    }

    /// Model to generate embeddings for tokenized input text
    var textEncoder: TextEncoderModel

    /// Model used to predict noise residuals given an input, diffusion time step, and conditional embedding
    var unet: Unet

    /// Model used to generate final image from latent diffusion process
    var decoder: Decoder

    /// Model used to latent space for image2image, and soon, in-painting
    var encoder: Encoder?

    /// Optional model for checking safety of generated image
    var safetyChecker: SafetyChecker? = nil

    /// Optional model used before Unet to control generated images by additional inputs
    var controlNet: ControlNet? = nil

    /// Reports whether this pipeline can perform safety checks
    public var canSafetyCheck: Bool {
        safetyChecker != nil
    }

    /// Option to reduce memory during image generation
    ///
    /// If true, the pipeline will lazily load TextEncoder, Unet, Decoder, and SafetyChecker
    /// when needed and aggressively unload their resources after
    ///
    /// This will increase latency in favor of reducing memory
    var reduceMemory: Bool = false

    /// Option to use system multilingual NLContextualEmbedding as encoder
    var useMultilingualTextEncoder: Bool = false

    /// Optional natural language script to use for the text encoder.
    var script: Script? = nil

    /// Creates a pipeline using the specified models and tokenizer
    ///
    /// - Parameters:
    ///   - textEncoder: Model for encoding tokenized text
    ///   - unet: Model for noise prediction on latent samples
    ///   - decoder: Model for decoding latent sample to image
    ///   - controlNet: Optional model to control generated images by additional inputs
    ///   - safetyChecker: Optional model for checking safety of generated images
    ///   - reduceMemory: Option to enable reduced memory mode
    /// - Returns: Pipeline ready for image generation
    public init(
        textEncoder: TextEncoderModel,
        unet: Unet,
        decoder: Decoder,
        encoder: Encoder?,
        controlNet: ControlNet? = nil,
        safetyChecker: SafetyChecker? = nil,
        reduceMemory: Bool = false
    ) {
        self.textEncoder = textEncoder
        self.unet = unet
        self.decoder = decoder
        self.encoder = encoder
        self.controlNet = controlNet
        self.safetyChecker = safetyChecker
        self.reduceMemory = reduceMemory
    }

    /// Creates a pipeline using the specified models and tokenizer
    ///
    /// - Parameters:
    ///   - textEncoder: Model for encoding tokenized text
    ///   - unet: Model for noise prediction on latent samples
    ///   - decoder: Model for decoding latent sample to image
    ///   - controlNet: Optional model to control generated images by additional inputs
    ///   - safetyChecker: Optional model for checking safety of generated images
    ///   - reduceMemory: Option to enable reduced memory mode
    ///   - useMultilingualTextEncoder: Option to use system multilingual NLContextualEmbedding as encoder
    ///   - script: Optional natural language script to use for the text encoder.
    /// - Returns: Pipeline ready for image generation
    @available(iOS 17.0, macOS 14.0, *)
    public init(
        textEncoder: TextEncoderModel,
        unet: Unet,
        decoder: Decoder,
        encoder: Encoder?,
        controlNet: ControlNet? = nil,
        safetyChecker: SafetyChecker? = nil,
        reduceMemory: Bool = false,
        useMultilingualTextEncoder: Bool = false,
        script: Script? = nil
    ) {
        self.textEncoder = textEncoder
        self.unet = unet
        self.decoder = decoder
        self.encoder = encoder
        self.controlNet = controlNet
        self.safetyChecker = safetyChecker
        self.reduceMemory = reduceMemory
        self.useMultilingualTextEncoder = useMultilingualTextEncoder
        self.script = script
    }

    /// Load required resources for this pipeline
    ///
    /// If reducedMemory is true this will instead call prewarmResources instead
    /// and let the pipeline lazily load resources as needed
    public func loadResources() throws {
        if reduceMemory {
            try prewarmResources()
        } else {
            try unet.loadResources()
            try textEncoder.loadResources()
            try decoder.loadResources()
            try encoder?.loadResources()
            try controlNet?.loadResources()
            try safetyChecker?.loadResources()
        }
    }

    /// Unload the underlying resources to free up memory
    public func unloadResources() {
        textEncoder.unloadResources()
        unet.unloadResources()
        decoder.unloadResources()
        encoder?.unloadResources()
        controlNet?.unloadResources()
        safetyChecker?.unloadResources()
    }

    // Prewarm resources one at a time
    public func prewarmResources() throws {
        try textEncoder.prewarmResources()
        try unet.prewarmResources()
        try decoder.prewarmResources()
        try encoder?.prewarmResources()
        try controlNet?.prewarmResources()
        try safetyChecker?.prewarmResources()
    }

    /// Image generation using stable diffusion
    /// - Parameters:
    ///   - configuration: Image generation configuration
    ///   - progressHandler: Callback to perform after each step, stops on receiving false response
    /// - Returns: An array of `imageCount` optional images.
    ///            The images will be nil if safety checks were performed and found the result to be un-safe
    public func generateImages(
        configuration config: Configuration,
        progressHandler: (Progress) -> Bool = { _ in true }
    ) throws -> [CGImage?] {

        // Encode the input prompt and negative prompt
        let promptEmbedding = try textEncoder.encode(config.prompt)
        let negativePromptEmbedding = try textEncoder.encode(config.negativePrompt)

        if reduceMemory {
            textEncoder.unloadResources()
        }

        // Convert to Unet hidden state representation
        // Concatenate the prompt and negative prompt embeddings
        let concatEmbedding = MLShapedArray<Float32>(
            concatenating: [negativePromptEmbedding, promptEmbedding],
            alongAxis: 0
        )

        let hiddenStates = useMultilingualTextEncoder ? concatEmbedding : toHiddenStates(concatEmbedding)

        /// Setup schedulers
        // NOTE(review): reconstructed — this span was destroyed by tag-stripping
        // in the corrupted source; rebuilt to match the surrounding uses of
        // `scheduler[i]` and `config.schedulerType`.
        let scheduler: [Scheduler] = (0..<config.imageCount).map { _ in
            switch config.schedulerType {
            case .pndmScheduler: return PNDMScheduler(stepCount: config.stepCount)
            case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: config.stepCount)
            }
        }

        // Generate random latent samples from specified seed
        var latents: [MLShapedArray<Float32>] = try generateLatentSamples(configuration: config, scheduler: scheduler[0])

        if reduceMemory {
            encoder?.unloadResources()
        }

        // Only apply partial-denoising strength in image2image mode
        let timestepStrength: Float? = config.mode == .imageToImage ? config.strength : nil

        // Convert cgImage for ControlNet into MLShapedArray
        let controlNetConds = try config.controlNetInputs.map { cgImage in
            let shapedArray = try cgImage.plannerRGBShapedArray(minValue: 0.0, maxValue: 1.0)
            // Duplicated along axis 0 to match the classifier-free-guidance batch of 2
            return MLShapedArray<Float32>(
                concatenating: [shapedArray, shapedArray],
                alongAxis: 0
            )
        }

        // De-noising loop
        let timeSteps: [Int] = scheduler[0].calculateTimesteps(strength: timestepStrength)
        for (step, t) in timeSteps.enumerated() {

            // Expand the latents for classifier-free guidance
            // and input to the Unet noise prediction model
            let latentUnetInput = latents.map {
                MLShapedArray<Float32>(concatenating: [$0, $0], alongAxis: 0)
            }

            // Before Unet, execute controlNet and add the output into Unet inputs
            let additionalResiduals = try controlNet?.execute(
                latents: latentUnetInput,
                timeStep: t,
                hiddenStates: hiddenStates,
                images: controlNetConds
            )

            // Predict noise residuals from latent samples
            // and current time step conditioned on hidden states
            var noise = try unet.predictNoise(
                latents: latentUnetInput,
                timeStep: t,
                hiddenStates: hiddenStates,
                additionalResiduals: additionalResiduals
            )

            noise = performGuidance(noise, config.guidanceScale)

            // Have the scheduler compute the previous (t-1) latent
            // sample given the predicted noise and current sample
            // NOTE(review): loop body reconstructed — destroyed by tag-stripping.
            for i in 0..<config.imageCount {
                latents[i] = scheduler[i].step(
                    output: noise[i],
                    timeStep: t,
                    sample: latents[i]
                )
            }

            // Report progress
            let progress = Progress(
                pipeline: self,
                prompt: config.prompt,
                step: step,
                stepCount: timeSteps.count,
                currentLatentSamples: latents,
                configuration: config
            )
            if !progressHandler(progress) {
                // Stop if requested by handler
                return []
            }
        }

        if reduceMemory {
            controlNet?.unloadResources()
            unet.unloadResources()
        }

        // Decode the latent samples to images
        return try decodeToImages(latents, configuration: config)
    }

    /// Creates the random source matching the requested RNG implementation
    func randomSource(from rng: StableDiffusionRNG, seed: UInt32) -> RandomSource {
        switch rng {
        case .numpyRNG:
            return NumPyRandomSource(seed: seed)
        case .torchRNG:
            return TorchRandomSource(seed: seed)
        }
    }

    /// Generates the initial latent samples, either pure noise (text2image)
    /// or the encoded starting image with noise added (image2image)
    func generateLatentSamples(configuration config: Configuration, scheduler: Scheduler) throws -> [MLShapedArray<Float32>] {
        var sampleShape = unet.latentSampleShape
        sampleShape[0] = 1

        let stdev = scheduler.initNoiseSigma
        var random = randomSource(from: config.rngType, seed: config.seed)
        let samples = (0..<config.imageCount).map { _ in
            MLShapedArray<Float32>(
                converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev)))
        }
        if let image = config.startingImage, config.mode == .imageToImage {
            guard let encoder else {
                throw Error.startingImageProvidedWithoutEncoder
            }
            let latent = try encoder.encode(image, scaleFactor: config.encoderScaleFactor, random: &random)
            return scheduler.addNoise(originalSample: latent, noise: samples, strength: config.strength)
        }
        return samples
    }

    func toHiddenStates(_ embedding: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
        // Unoptimized manual transpose [0, 2, None, 1]
        // e.g. From [2, 77, 768] to [2, 768, 1, 77]
        let fromShape = embedding.shape
        let stateShape = [fromShape[0], fromShape[2], 1, fromShape[1]]
        var states = MLShapedArray<Float32>(repeating: 0.0, shape: stateShape)
        // NOTE(review): transpose loop reconstructed — destroyed by tag-stripping.
        for i0 in 0..<fromShape[0] {
            for i1 in 0..<fromShape[1] {
                for i2 in 0..<fromShape[2] {
                    states[scalarAt: i0, i2, 0, i1] = embedding[scalarAt: i0, i1, i2]
                }
            }
        }
        return states
    }

    /// Applies classifier-free guidance to each noise prediction in the batch
    func performGuidance(_ noise: [MLShapedArray<Float32>], _ guidanceScale: Float) -> [MLShapedArray<Float32>] {
        noise.map { performGuidance($0, guidanceScale) }
    }

    /// Combines the unconditioned (index 0) and text-conditioned (index 1)
    /// halves of a guidance batch into a single guided prediction
    func performGuidance(_ noise: MLShapedArray<Float32>, _ guidanceScale: Float) -> MLShapedArray<Float32> {
        var shape = noise.shape
        shape[0] = 1
        return MLShapedArray<Float32>(unsafeUninitializedShape: shape) { result, _ in
            noise.withUnsafeShapedBufferPointer { scalars, _, strides in
                for i in 0 ..< result.count {
                    // unconditioned + guidance*(text - unconditioned)
                    result.initializeElement(
                        at: i,
                        to: scalars[i] + guidanceScale * (scalars[strides[0] + i] - scalars[i])
                    )
                }
            }
        }
    }

    /// Decodes latents to images and, when enabled, nils out any image the
    /// safety checker rejects
    func decodeToImages(_ latents: [MLShapedArray<Float32>], configuration config: Configuration) throws -> [CGImage?] {
        let images = try decoder.decode(latents, scaleFactor: config.decoderScaleFactor)

        if reduceMemory {
            decoder.unloadResources()
        }

        // If safety is disabled return what was decoded
        if config.disableSafety {
            return images
        }

        // If there is no safety checker return what was decoded
        guard let safetyChecker = safetyChecker else {
            return images
        }

        // Otherwise change images which are not safe to nil
        let safeImages = try images.map { image in
            try safetyChecker.isSafe(image) ? image : nil
        }

        if reduceMemory {
            safetyChecker.unloadResources()
        }

        return safeImages
    }
}

@available(iOS 16.2, macOS 13.1, *)
extension StableDiffusionPipeline {
    /// Sampling progress details
    public struct Progress {
        public let pipeline: StableDiffusionPipeline
        public let prompt: String
        public let step: Int
        public let stepCount: Int
        public let currentLatentSamples: [MLShapedArray<Float32>]
        public let configuration: Configuration
        public var isSafetyEnabled: Bool {
            pipeline.canSafetyCheck && !configuration.disableSafety
        }
        public var currentImages: [CGImage?] {
            // NOTE(review): `try!` will trap if decoding fails mid-progress;
            // kept for interface compatibility (non-throwing property)
            try! pipeline.decodeToImages(currentLatentSamples, configuration: configuration)
        }
    }
}

// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.

import Foundation
import CoreGraphics

@available(iOS 16.2, macOS 13.1, *)
extension StableDiffusionPipeline {

    /// Type of processing that will be performed to generate an image
    public enum Mode {
        case textToImage
        case imageToImage
        // case inPainting
    }

    /// Image generation configuration
    public struct Configuration: Hashable {

        /// Text prompt to guide sampling
        public var prompt: String
        /// Negative text prompt to guide sampling
        public var negativePrompt: String = ""
        /// Starting image for image2image or in-painting
        public var startingImage: CGImage? = nil
        //public var maskImage: CGImage? = nil
        public var strength: Float = 1.0
        /// Number of images to generate
        public var imageCount: Int = 1
        /// Number of inference steps to perform
        public var stepCount: Int = 50
        /// Random seed which to start generation
        public var seed: UInt32 = 0
        /// Controls the influence of the text prompt on sampling process (0=random images)
        public var guidanceScale: Float = 7.5
        /// List of Images for available ControlNet Models
        public var controlNetInputs: [CGImage] = []
        /// Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
        public var disableSafety: Bool = false
        /// The type of Scheduler to use.
        public var schedulerType: StableDiffusionScheduler = .pndmScheduler
        /// The type of RNG to use
        public var rngType: StableDiffusionRNG = .numpyRNG
        /// Scale factor to use on the latent after encoding
        public var encoderScaleFactor: Float32 = 0.18215
        /// Scale factor to use on the latent before decoding
        public var decoderScaleFactor: Float32 = 0.18215

        /// Given the configuration, what mode will be used for generation
        public var mode: Mode {
            guard startingImage != nil else {
                return .textToImage
            }
            guard strength < 1.0 else {
                return .textToImage
            }
            return .imageToImage
        }

        public init(
            prompt: String
        ) {
            self.prompt = prompt
        }
    }
}

// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
import Foundation
import CoreML
import NaturalLanguage

@available(iOS 16.2, macOS 13.1, *)
public extension StableDiffusionPipeline {

    /// Well-known locations of every model and tokenization resource the
    /// pipeline may need, all rooted at a single base directory.
    struct ResourceURLs {

        public let textEncoderURL: URL
        public let unetURL: URL
        public let unetChunk1URL: URL
        public let unetChunk2URL: URL
        public let decoderURL: URL
        public let encoderURL: URL
        public let safetyCheckerURL: URL
        public let vocabURL: URL
        public let mergesURL: URL
        public let controlNetDirURL: URL
        public let controlledUnetURL: URL
        public let controlledUnetChunk1URL: URL
        public let controlledUnetChunk2URL: URL
        public let multilingualTextEncoderProjectionURL: URL

        public init(resourcesAt baseURL: URL) {
            // Small helper so each resource reads as its file name
            func resource(_ name: String) -> URL {
                baseURL.appending(path: name)
            }
            textEncoderURL = resource("TextEncoder.mlmodelc")
            unetURL = resource("Unet.mlmodelc")
            unetChunk1URL = resource("UnetChunk1.mlmodelc")
            unetChunk2URL = resource("UnetChunk2.mlmodelc")
            decoderURL = resource("VAEDecoder.mlmodelc")
            encoderURL = resource("VAEEncoder.mlmodelc")
            safetyCheckerURL = resource("SafetyChecker.mlmodelc")
            vocabURL = resource("vocab.json")
            mergesURL = resource("merges.txt")
            controlNetDirURL = resource("controlnet")
            controlledUnetURL = resource("ControlledUnet.mlmodelc")
            controlledUnetChunk1URL = resource("ControlledUnetChunk1.mlmodelc")
            controlledUnetChunk2URL = resource("ControlledUnetChunk2.mlmodelc")
            multilingualTextEncoderProjectionURL = resource("MultilingualTextEncoderProjection.mlmodelc")
        }
    }

    /// Create stable diffusion pipeline using model resources at a
    /// specified URL
    ///
    /// - Parameters:
    ///   - baseURL: URL pointing to directory holding all model and tokenization resources
    ///   - controlNetModelNames: Specify ControlNet models to use in generation
    ///   - configuration: The configuration to load model resources with
    ///   - disableSafety: Load time disable of safety to save memory
    ///   - reduceMemory: Setup pipeline in reduced memory mode
    ///   - useMultilingualTextEncoder: Option to use system multilingual NLContextualEmbedding as encoder
    ///   - script: Optional natural language script to use for the text encoder.
    /// - Returns:
    ///   Pipeline ready for image generation if all necessary resources loaded
    init(
        resourcesAt baseURL: URL,
        controlNet controlNetModelNames: [String],
        configuration config: MLModelConfiguration = .init(),
        disableSafety: Bool = false,
        reduceMemory: Bool = false,
        useMultilingualTextEncoder: Bool = false,
        script: Script? = nil
    ) throws {

        // Expected location of each resource
        let urls = ResourceURLs(resourcesAt: baseURL)
        func fileExists(_ url: URL) -> Bool {
            FileManager.default.fileExists(atPath: url.path)
        }

        // Text encoder: either the system multilingual contextual embedding
        // (requires a newer OS) or the bundled encoder with its BPE tokenizer
        let textEncoder: TextEncoderModel
        if useMultilingualTextEncoder {
            guard #available(macOS 14.0, iOS 17.0, *) else {
                throw Error.unsupportedOSVersion
            }
            textEncoder = MultilingualTextEncoder(
                modelAt: urls.multilingualTextEncoderProjectionURL,
                configuration: config,
                script: script ?? .latin
            )
        } else {
            let tokenizer = try BPETokenizer(mergesAt: urls.mergesURL, vocabularyAt: urls.vocabURL)
            textEncoder = TextEncoder(
                tokenizer: tokenizer,
                modelAt: urls.textEncoderURL,
                configuration: config
            )
        }

        // ControlNet model: only constructed when model names were requested
        let controlNetURLs = controlNetModelNames.map { name in
            urls.controlNetDirURL.appending(path: name + ".mlmodelc")
        }
        let controlNet: ControlNet? = controlNetURLs.isEmpty
            ? nil
            : ControlNet(modelAt: controlNetURLs, configuration: config)

        // Unet model: when a ControlNet is in play, use the variant that
        // accepts the additional residual inputs. Prefer the chunked form
        // whenever both chunks are present on disk.
        let (wholeUnetURL, chunk1URL, chunk2URL) = controlNet == nil
            ? (urls.unetURL, urls.unetChunk1URL, urls.unetChunk2URL)
            : (urls.controlledUnetURL, urls.controlledUnetChunk1URL, urls.controlledUnetChunk2URL)

        let unet: Unet
        if fileExists(chunk1URL), fileExists(chunk2URL) {
            unet = Unet(chunksAt: [chunk1URL, chunk2URL], configuration: config)
        } else {
            unet = Unet(modelAt: wholeUnetURL, configuration: config)
        }

        // Image Decoder
        let decoder = Decoder(modelAt: urls.decoderURL, configuration: config)

        // Optional safety checker: loaded only when enabled and present
        let safetyChecker: SafetyChecker? = (!disableSafety && fileExists(urls.safetyCheckerURL))
            ? SafetyChecker(modelAt: urls.safetyCheckerURL, configuration: config)
            : nil

        // Optional Image Encoder (needed for image2image)
        let encoder: Encoder? = fileExists(urls.encoderURL)
            ? Encoder(modelAt: urls.encoderURL, configuration: config)
            : nil

        // Construct pipeline, using the multilingual-capable initializer
        // only where the OS supports it
        if #available(macOS 14.0, iOS 17.0, *) {
            self.init(
                textEncoder: textEncoder,
                unet: unet,
                decoder: decoder,
                encoder: encoder,
                controlNet: controlNet,
                safetyChecker: safetyChecker,
                reduceMemory: reduceMemory,
                useMultilingualTextEncoder: useMultilingualTextEncoder,
                script: script
            )
        } else {
            self.init(
                textEncoder: textEncoder,
                unet: unet,
                decoder: decoder,
                encoder: encoder,
                controlNet: controlNet,
                safetyChecker: safetyChecker,
                reduceMemory: reduceMemory
            )
        }
    }
}