#include "slang-ir-autodiff-rev.h"

#include "slang-ir-autodiff-cfg-norm.h"
#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-clone.h"
#include "slang-ir-dce.h"
#include "slang-ir-dominators.h"
#include "slang-ir-eliminate-multilevel-break.h"
#include "slang-ir-init-local-var.h"
#include "slang-ir-inline.h"
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-loop-unroll.h"
#include "slang-ir-redundancy-removal.h"
#include "slang-ir-single-return.h"
#include "slang-ir-ssa-simplification.h"
#include "slang-ir-util.h"

namespace Slang
{
// Builds the function type of a backward-derivative (propagate-style) function from the type of
// the original function.
//
// The resulting parameter list is, in order:
//   1. one entry per original parameter for which `transcribeParamTypeForPropagateFunc` yields a
//      type (parameters that yield null are dropped),
//   2. the differential of the original result type, if `differentiateType` yields one,
//   3. `intermeidateType`, if non-null.
// The resulting function always returns `void`.
IRFuncType* BackwardDiffTranscriberBase::differentiateFunctionTypeImpl(
    IRBuilder* builder,
    IRFuncType* funcType,
    IRInst* intermeidateType)
{
    List<IRType*> propagateParamTypes;

    for (UIndex paramIndex = 0; paramIndex < funcType->getParamCount(); ++paramIndex)
    {
        auto propagateType =
            transcribeParamTypeForPropagateFunc(builder, funcType->getParamType(paramIndex));
        if (!propagateType)
            continue;
        propagateParamTypes.add(propagateType);
    }

    // The incoming derivative of the result (if the result is differentiable).
    auto diffResultType = differentiateType(builder, funcType->getResultType());
    if (diffResultType)
        propagateParamTypes.add(diffResultType);

    // The optional trailing intermediate-context parameter.
    if (intermeidateType)
        propagateParamTypes.add((IRType*)intermeidateType);

    IRType* voidResultType = builder->getVoidType();
    return builder->getFuncType(propagateParamTypes, voidResultType);
}

// Builds the function type of the primal function generated for `func`.
//
// The parameter list is the primal-transcribed type of every original parameter, followed by an
// `out` parameter of the backward-diff intermediate context type. The result type is the
// primal-transcribed original result type.
IRFuncType* BackwardDiffPrimalTranscriber::differentiateFunctionType(
    IRBuilder* builder,
    IRInst* func,
    IRFuncType* funcType)
{
    IRType* contextType =
        builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func));

    // When emitting inside a generic, the context type is specialized with that generic.
    auto enclosingGeneric = findOuterGeneric(builder->getInsertLoc().getParent());
    if (enclosingGeneric)
    {
        contextType =
            (IRType*)specializeWithGeneric(*builder, contextType, as<IRGeneric>(enclosingGeneric));
    }

    auto outContextType = builder->getOutType(contextType);

    List<IRType*> primalParamTypes;
    for (UInt paramIndex = 0; paramIndex < funcType->getParamCount(); ++paramIndex)
    {
        primalParamTypes.add(
            transcribeParamTypeForPrimalFunc(builder, funcType->getParamType(paramIndex)));
    }
    primalParamTypes.add(outContextType);

    auto primalResultType =
        (IRType*)findOrTranscribePrimalInst(builder, funcType->getResultType());
    return builder->getFuncType(primalParamTypes, primalResultType);
}

// Links `diffFunc` to `primalFunc` as its transcribed function and marks `diffFunc` with
// `IgnoreSideEffectsDecoration`. Returns the (primalFunc, diffFunc) pair.
InstPair BackwardDiffPrimalTranscriber::transcribeFunc(
    IRBuilder* builder,
    IRFunc* primalFunc,
    IRFunc* diffFunc)
{
    // Nothing but decorations is emitted here: the body of the primal func is produced by the
    // propagate transcriber, at the same time as the propagate func itself.
    addTranscribedFuncDecoration(*builder, primalFunc, diffFunc);
    builder->addDecoration(diffFunc, kIROp_IgnoreSideEffectsDecoration);

    InstPair transcribed(primalFunc, diffFunc);
    return transcribed;
}

// Emits one parameter at the builder's current insert location for every parameter type in the
// function type of `func`, and returns the emitted parameters in declaration order.
static List<IRInst*> _defineFuncParams(IRBuilder* builder, IRFunc* func)
{
    auto funcType = cast<IRFuncType>(func->getFullType());

    List<IRInst*> emittedParams;
    for (UInt index = 0; index < funcType->getParamCount(); ++index)
    {
        auto emitted = builder->emitParam(funcType->getParamType(index));
        emittedParams.add(emitted);
    }
    return emittedParams;
}

// Fills in the propagate function (`diffPropFunc`) and the primal function of `originalFunc` for
// the case where `originalFunc` carries a user-defined backward derivative (`udfDecor`).
//
// - An empty struct is created and registered on `originalFunc` as its backward-derivative
//   intermediate type.
// - `diffPropFunc` receives a single block that forwards all of its parameters except the last
//   one to the user-defined derivative and then returns.
// - The primal function (whose header is created on demand through the primal transcriber)
//   receives a single block that forwards all of its parameters except the last one to
//   `originalFunc` and returns the call result.
//
// Note that this function moves the insert location of `builder`.
void BackwardDiffPropagateTranscriber::generateTrivialDiffFuncFromUserDefinedDerivative(
    IRBuilder* builder,
    IRFunc* originalFunc,
    IRFunc* diffPropFunc,
    IRUserDefinedBackwardDerivativeDecoration* udfDecor)
{
    // Create an empty struct type to use as the intermediate context type.
    auto originalGeneric = findOuterGeneric(originalFunc);
    builder->setInsertBefore(originalFunc);
    IRInst* emptyStruct = builder->createStructType();
    IRInst* emptyStructType = nullptr;
    auto emptyStructGeneric = hoistValueFromGeneric(*builder, emptyStruct, emptyStructType, false);
    builder->addBackwardDerivativeIntermediateTypeDecoration(originalFunc, emptyStructGeneric);

    // Emit the body of the propagate func: a single block that calls the user-defined derivative.
    IRInst* udf = udfDecor->getBackwardDerivativeFunc();
    builder->setInsertInto(diffPropFunc);
    builder->emitBlock();
    List<IRInst*> params = _defineFuncParams(builder, diffPropFunc);
    // The last parameter is not forwarded to the user-defined derivative
    // (presumably it is the intermediate-context parameter -- confirm against
    // `differentiateFunctionType`).
    params.removeLast();
    IRInst* udfRefFromPropFunc = udf;
    if (auto specialize = as<IRSpecialize>(udf))
    {
        // The derivative is referenced through a `specialize`: from here on `udf` is the
        // unspecialized base, and the reference used in the call is re-specialized against the
        // generic that encloses the propagate func.
        udf = specialize->getBase();
        auto propGeneric = findOuterGeneric(diffPropFunc);
        SLANG_RELEASE_ASSERT(propGeneric);
        udfRefFromPropFunc = maybeSpecializeWithGeneric(*builder, udf, propGeneric);
    }
    builder->emitCallInst(builder->getVoidType(), udfRefFromPropFunc, params);
    builder->emitReturn();

    // Copy decorations from the user-defined derivative (`udf`) to the generated propagate func.
    copyOriginalDecorations(udf, diffPropFunc);

    // Now create the trivial primal function.
    auto existingDecor = originalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>();
    if (!existingDecor)
    {
        // We haven't created a header for primal func yet, create it now.
        if (originalGeneric)
            builder->setInsertBefore(originalGeneric);
        else
            builder->setInsertBefore(originalFunc);

        autoDiffSharedContext->transcriberSet.primalTranscriber->transcribe(
            builder,
            originalGeneric ? originalGeneric : originalFunc);
        // Transcribing is expected to have attached the primal decoration; look it up again.
        existingDecor = originalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>();
    }
    SLANG_RELEASE_ASSERT(existingDecor);

    // Fill the primal func header with trivial call to original func.
    IRInst* existingPrimalFunc = existingDecor->getBackwardDerivativePrimalFunc();
    IRGeneric* existingPriamlFuncGeneric = nullptr;
    if (auto specialize = as<IRSpecialize>(existingPrimalFunc))
    {
        // The decoration refers to a specialization: work on the generic's return value instead.
        existingPriamlFuncGeneric = as<IRGeneric>(specialize->getBase());
        existingPrimalFunc = findGenericReturnVal(existingPriamlFuncGeneric);
    }
    // NOTE(review): this insert location appears to be superseded immediately by the
    // `setInsertInto` call below -- confirm whether this call is still needed.
    builder->setInsertBefore(existingPrimalFunc);

    builder->setInsertInto(existingPrimalFunc);

    // Prefer a checkpoint hint on the user-defined derivative; fall back to the one on the
    // original func.
    auto checkpointHint = udf->findDecoration<IRCheckpointHintDecoration>();
    if (!checkpointHint)
        checkpointHint = originalFunc->findDecoration<IRCheckpointHintDecoration>();
    if (checkpointHint)
        cloneCheckpointHint(
            builder,
            checkpointHint,
            cast<IRGlobalValueWithCode>(existingPrimalFunc));

    // Copy decorations from the user-defined derivative (`udf`) to the primal func wrapper.
    copyOriginalDecorations(udf, existingPrimalFunc);

    builder->emitBlock();
    params = _defineFuncParams(builder, as<IRFunc>(existingPrimalFunc));
    // As above, the last parameter of the primal func is not forwarded to the original func.
    params.removeLast();

    // Unwrap any ref pairs. We need this special case for trivial funcs.
    for (Int i = 0; i < params.getCount(); i++)
    {
        if (as<IRDifferentialPtrPairType>(params[i]->getDataType()))
        {
            params[i] = builder->emitDifferentialPtrPairGetPrimal(params[i]);
        }
    }

    // When the original func lives in a generic, call it through a specialization made against
    // the primal func's generic.
    IRInst* originalFuncRefFromPrimalFunc = originalFunc;
    if (originalGeneric)
        originalFuncRefFromPrimalFunc =
            maybeSpecializeWithGeneric(*builder, originalGeneric, existingPriamlFuncGeneric);
    auto result = builder->emitCallInst(
        cast<IRFuncType>(existingPrimalFunc->getFullType())->getResultType(),
        originalFuncRefFromPrimalFunc,
        params);
    builder->emitReturn(result);
}

// Builds the function type of the propagate function generated for `func`.
//
// The trailing intermediate-context parameter is chosen as follows:
//   - emitting inside a generic: the context type of `func`, specialized with that generic;
//   - otherwise, if `func` is a witness-method lookup: no context parameter at all;
//   - otherwise: the context type of `func`.
IRFuncType* BackwardDiffPropagateTranscriber::differentiateFunctionType(
    IRBuilder* builder,
    IRInst* func,
    IRFuncType* funcType)
{
    auto enclosingGeneric = findOuterGeneric(builder->getInsertLoc().getParent());

    // Witness lookups emitted outside of any generic get no intermediate context.
    if (!enclosingGeneric && as<IRLookupWitnessMethod>(func))
        return differentiateFunctionTypeImpl(builder, funcType, nullptr);

    IRType* contextType =
        builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func));
    if (enclosingGeneric)
    {
        contextType =
            (IRType*)specializeWithGeneric(*builder, contextType, as<IRGeneric>(enclosingGeneric));
    }

    return differentiateFunctionTypeImpl(builder, funcType, contextType);
}

// Builds the function type of the combined backward-derivative function. This variant takes no
// intermediate-context parameter, and `func` is not consulted.
IRFuncType* BackwardDiffTranscriber::differentiateFunctionType(
    IRBuilder* builder,
    IRInst* func,
    IRFuncType* funcType)
{
    SLANG_UNUSED(func);

    IRInst* const noIntermediateType = nullptr;
    return differentiateFunctionTypeImpl(builder, funcType, noIntermediateType);
}

// Generates the body of the propagate function `diffFunc` for `primalFunc`.
//
// A function that carries a user-defined backward derivative gets trivial wrappers around that
// derivative; every other function goes through the regular transcription in
// `transcribeFuncImpl`.
InstPair BackwardDiffPropagateTranscriber::transcribeFunc(
    IRBuilder* builder,
    IRFunc* primalFunc,
    IRFunc* diffFunc)
{
    addTranscribedFuncDecoration(*builder, primalFunc, diffFunc);

    auto userDerivative =
        primalFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>();
    if (!userDerivative)
        transcribeFuncImpl(builder, primalFunc, diffFunc);
    else
        generateTrivialDiffFuncFromUserDefinedDerivative(
            builder,
            primalFunc,
            diffFunc,
            userDerivative);

    return InstPair(primalFunc, diffFunc);
}

// Dispatches transcription of `origInst` based on its opcode.
//
// Opcodes that are not listed here produce an empty pair (both entries null).
InstPair BackwardDiffTranscriberBase::transcribeInstImpl(IRBuilder* builder, IRInst* origInst)
{
    const auto opcode = origInst->getOp();
    switch (opcode)
    {
    // Instructions that have a dedicated transcription routine.
    case kIROp_Param:
        return transcribeParam(builder, as<IRParam>(origInst));
    case kIROp_Return:
        return transcribeReturn(builder, as<IRReturn>(origInst));
    case kIROp_LookupWitness:
        return transcribeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst));
    case kIROp_Specialize:
        return transcribeSpecialize(builder, as<IRSpecialize>(origInst));

    // Instructions that are handled as non-differentiable.
    case kIROp_MakeTuple:
    case kIROp_FloatLit:
    case kIROp_IntLit:
    case kIROp_VoidLit:
    case kIROp_ExtractExistentialWitnessTable:
    case kIROp_ExtractExistentialType:
    case kIROp_ExtractExistentialValue:
    case kIROp_WrapExistential:
    case kIROp_MakeExistential:
    case kIROp_MakeExistentialWithRTTI:
        return transcribeNonDiffInst(builder, origInst);

    // Struct keys are reused as-is and have no differential.
    case kIROp_StructKey:
        return InstPair(origInst, nullptr);

    default:
        break;
    }

    return InstPair(nullptr, nullptr);
}

// Produces the name hint "dp<var-name>" for a parameter derived from `origVar`.
// When `origVar` carries no name hint, the result is an empty string.
//
String BackwardDiffTranscriberBase::makeDiffPairName(IRInst* origVar)
{
    auto nameHint = origVar->findDecoration<IRNameHintDecoration>();
    if (!nameHint)
        return String("");

    return ("dp" + String(nameHint->getName()));
}

// If `origType` (or, for an `IROutTypeBase` pointer type, its value type) is an attributed type
// carrying the no-diff attribute, returns its primal-transcribed type, re-wrapped in the same
// kind of pointer when the input was one. Returns null in every other case.
static IRType* _getPrimalTypeFromNoDiffType(
    BackwardDiffTranscriberBase* transcriber,
    IRBuilder* builder,
    IRType* origType)
{
    // Look through an out/inout-style pointer to the underlying value type.
    auto outerPtrType = as<IROutTypeBase>(origType);
    IRType* valueType = outerPtrType ? outerPtrType->getValueType() : origType;

    auto attributedType = as<IRAttributedType>(valueType);
    if (!attributedType || !attributedType->findAttr<IRNoDiffAttr>())
        return nullptr;

    auto primalValueType = (IRType*)transcriber->findOrTranscribePrimalInst(builder, valueType);
    if (!outerPtrType)
        return primalValueType;

    return builder->getPtrType(outerPtrType->getOp(), primalValueType);
}

// Computes the type a parameter of type `paramType` has in the generated primal function.
IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPrimalFunc(
    IRBuilder* builder,
    IRType* paramType)
{
    // no_diff parameters simply keep their primal type.
    if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType))
        return primalNoDiffType;

    auto primalType = (IRType*)findOrTranscribePrimalInst(builder, paramType);
    if (!differentiableTypeConformanceContext.isDifferentiablePtrType(primalType))
        return primalType;

    // Differentiable pointer types take no part in the transposition process, so they are
    // passed to the primal function as pairs.
    //
    auto diffPairType = tryGetDiffPairType(builder, primalType);
    SLANG_ASSERT(diffPairType);
    return diffPairType;
}

// Computes the type a parameter of type `paramType` has in the generated propagate function.
//
//   - `out T`             -> the differential of T (whatever `differentiateType` yields)
//   - no_diff types       -> the primal type, with an `inout` wrapper stripped
//   - types with a pair   -> the pair type, wrapped in `inout` unless it already is a relevant
//                            pointer type or a differential-pointer pair
//   - everything else     -> the primal type, with an `inout` wrapper stripped
IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPropagateFunc(
    IRBuilder* builder,
    IRType* paramType)
{
    if (auto outType = as<IROutType>(paramType))
        return differentiateType(builder, outType->getValueType());

    // Replaces `inout T` by `T`; any other type passes through untouched.
    auto stripInOut = [](IRType* type) -> IRType*
    {
        auto inoutType = as<IRInOutType>(type);
        if (!inoutType)
            return type;
        return inoutType->getValueType();
    };

    // If the param is marked as no_diff, return the primal type.
    if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType))
        return stripInOut(primalNoDiffType);

    auto diffPairType = tryGetDiffPairType(builder, paramType);
    if (!diffPairType)
    {
        auto primalType = (IRType*)findOrTranscribePrimalInst(builder, paramType);
        return stripInOut(primalType);
    }

    if (asRelevantPtrType(diffPairType) || as<IRDifferentialPtrPairType>(diffPairType))
        return diffPairType;
    return builder->getInOutType(diffPairType);
}

// Creates an empty function that stands for the transcribed version of `origFunc`: it gets the
// differentiated function type, a name hint, and the relevant decorations, but no body.
//
// Returns (origFunc, newFunc), or an empty pair when `origFunc` is neither backward
// differentiable nor marked as treat-as-differentiable.
InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(
    IRBuilder* inBuilder,
    IRFunc* origFunc)
{
    const bool isDifferentiable = isBackwardDifferentiableFunc(origFunc) ||
                                  origFunc->findDecoration<IRTreatAsDifferentiableDecoration>();
    if (!isDifferentiable)
        return InstPair(nullptr, nullptr);

    // Use a private copy of the builder for everything emitted below.
    IRBuilder builder = *inBuilder;

    maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc);
    differentiableTypeConformanceContext.setFunc(origFunc);

    auto diffFunc = builder.createFunc();

    SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
    builder.setInsertBefore(diffFunc);

    auto origFuncType = as<IRFuncType>(origFunc->getFullType());
    IRType* diffFuncType = this->differentiateFunctionType(&builder, origFunc, origFuncType);
    diffFunc->setFullType(diffFuncType);

    // Only named functions get a (derived) name hint.
    if (origFunc->findDecoration<IRNameHintDecoration>())
    {
        auto transcribedName = this->getTranscribedFuncName(&builder, origFunc);
        builder.addNameHintDecoration(diffFunc, transcribedName);
    }
    addTranscribedFuncDecoration(builder, origFunc, diffFunc);

    // Carry the checkpoint hints over to the new function.
    copyCheckpointHints(&builder, origFunc, diffFunc);

    // The generated derivative function is itself marked as differentiable.
    builder.addBackwardDifferentiableDecoration(diffFunc);

    copyOriginalDecorations(origFunc, diffFunc);
    builder.addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast);

    return InstPair(origFunc, diffFunc);
}

// Records `transcribedFunc` on `origFunc` as its existing derivative function.
//
// When `transcribedFunc` lives inside a generic, the recorded reference is that generic
// specialized with the outer generic of `origFunc`, emitted right before `origFunc`.
void BackwardDiffTranscriberBase::addTranscribedFuncDecoration(
    IRBuilder& builder,
    IRFunc* origFunc,
    IRFunc* transcribedFunc)
{
    IRBuilder subBuilder = builder;

    auto transcribedGeneric = findOuterGeneric(transcribedFunc);
    if (!transcribedGeneric)
    {
        addExistingDiffFuncDecor(&subBuilder, origFunc, transcribedFunc);
        return;
    }

    subBuilder.setInsertBefore(origFunc);
    auto specialized = specializeWithGeneric(
        subBuilder,
        transcribedGeneric,
        as<IRGeneric>(findOuterGeneric(origFunc)));
    addExistingDiffFuncDecor(&subBuilder, origFunc, specialized);
}

// Creates the header of the transcribed function for `origFunc` and, if a header was produced,
// queues a follow-up task so that its body gets transcribed later.
InstPair BackwardDiffTranscriberBase::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
{
    // A function can be transcribed as a 'value' (i.e. it is the innermost return value of its
    // enclosing generic), in which case the insert location is kept unchanged. Otherwise it is
    // transcribed as a declaration and the header is inserted into the module.
    //
    auto origOuterGen = as<IRGeneric>(findOuterGeneric(origFunc));
    const bool isGenericValue =
        origOuterGen && findInnerMostGenericReturnVal(origOuterGen) == origFunc;

    InstPair result;
    if (isGenericValue)
    {
        result = transcribeFuncHeaderImpl(inBuilder, origFunc);
    }
    else
    {
        // Dealing with a declaration: emit at module scope.
        IRBuilder moduleBuilder = *inBuilder;
        moduleBuilder.setInsertInto(inBuilder->getModule());
        result = transcribeFuncHeaderImpl(&moduleBuilder, origFunc);
    }

    // Schedule the body transcription only when a header was actually created.
    if (auto resultFunc = as<IRFunc>(result.differential))
    {
        FuncBodyTranscriptionTask task;
        task.originalFunc = as<IRFunc>(result.primal);
        task.resultFunc = resultFunc;
        task.type = diffTaskType;
        autoDiffSharedContext->followUpFunctionsToTranscribe.add(task);
    }
    return result;
}

// Creates the combined backward-derivative function for `origFunc`.
//
// Unlike the base-class version, this emits a complete single-block body right away:
//   1. parameters are emitted using the propagate-function parameter types,
//   2. the primal function of `origFunc` is called with primal arguments derived from those
//      parameters, plus a local variable that receives the intermediate context,
//   3. the propagate function of `origFunc` is called with the parameters, the result
//      derivative (if any), and the loaded intermediate context,
//   4. the function returns.
//
// Returns the existing pair if `origFunc` already has a derivative function, and the (possibly
// empty) header pair otherwise.
InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
{
    // Reuse a previously generated derivative function.
    if (auto bwdDiffFunc = findExistingDiffFunc(origFunc))
        return InstPair(origFunc, bwdDiffFunc);

    auto header = transcribeFuncHeaderImpl(inBuilder, origFunc);
    if (!header.differential)
        return header;

    IRBuilder builder = *inBuilder;

    builder.setInsertInto(header.differential);
    builder.emitBlock();
    auto origFuncType = as<IRFuncType>(origFunc->getFullType());
    // Arguments and parameter types for the two calls emitted below.
    List<IRInst*> primalArgs, propagateArgs;
    List<IRType*> primalTypes, propagateTypes;
    IRType* primalResultType =
        transcribeParamTypeForPrimalFunc(&builder, origFuncType->getResultType());

    // `currentParam` walks the parameters of `origFunc` in lock-step with the parameter types of
    // its function type; it is only used to attribute source locations.
    IRParam* currentParam = origFunc->getFirstParam();
    for (UInt i = 0; i < origFuncType->getParamCount(); i++)
    {
        IRBuilderSourceLocRAII sourceLocationScope(&builder, currentParam->sourceLoc);

        auto primalParamType =
            transcribeParamTypeForPrimalFunc(&builder, origFuncType->getParamType(i));
        auto propagateParamType =
            transcribeParamTypeForPropagateFunc(&builder, origFuncType->getParamType(i));
        if (propagateParamType)
        {
            auto param = builder.emitParam(propagateParamType);
            propagateTypes.add(propagateParamType);
            propagateArgs.add(param);

            // Fetch primal values to use as arguments in primal func call.
            IRInst* primalArg = param;
            if (!as<IROutType>(primalParamType) && !as<IRConstRefType>(primalParamType))
            {
                // As long as the primal parameter is not an out or constref type,
                // we need to fetch the primal value from the parameter.
                if (asRelevantPtrType(propagateParamType))
                {
                    primalArg = builder.emitLoad(param);
                }
                if (const auto diffPairType = as<IRDifferentialPairType>(primalArg->getDataType()))
                {
                    primalArg = builder.emitDifferentialPairGetPrimal(primalArg);
                }
            }
            if (auto primalParamPtrType = isMutablePointerType(primalParamType))
            {
                // If primal parameter is mutable, we need to pass in a temp var.
                auto tempVar = builder.emitVar(primalParamPtrType->getValueType());

                // If the parameter is not a pure 'out' param, we also need to setup the initial
                // value of the temp var, otherwise the temp var will be uninitialized which could
                // cause undefined behavior in the primal function.
                //
                if (!as<IROutType>(primalParamType))
                    builder.emitStore(tempVar, primalArg);

                primalArgs.add(tempVar);
            }
            else
            {
                primalArgs.add(primalArg);
            }
        }
        else
        {
            // No propagate parameter exists for this original parameter. The primal parameter
            // must then be a pointer type; pass a fresh local variable of its value type.
            auto primalPtrType = asRelevantPtrType(primalParamType);
            SLANG_RELEASE_ASSERT(primalPtrType);
            auto primalValueType = primalPtrType->getValueType();
            auto var = builder.emitVar(primalValueType);
            primalArgs.add(var);
        }
        primalTypes.add(primalParamType);
        currentParam = currentParam->getNextParam();
    }

    // Add dOut argument to propagateArgs.
    auto diffResultType = differentiateType(&builder, origFunc->getResultType());
    if (diffResultType)
    {
        auto param = builder.emitParam(diffResultType);
        propagateArgs.add(param);
        propagateTypes.add(param->getFullType());
    }

    // When `origFunc` lives in a generic, both the callee reference and the intermediate context
    // type are specialized against the generic that encloses the new function.
    auto outerGeneric = findOuterGeneric(origFunc);
    IRType* intermediateType =
        builder.getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(origFunc));
    IRInst* specializedOriginalFunc = origFunc;
    if (outerGeneric)
    {
        specializedOriginalFunc = maybeSpecializeWithGeneric(
            builder,
            outerGeneric,
            findOuterGeneric(header.differential));
        intermediateType = (IRType*)specializeWithGeneric(
            builder,
            intermediateType,
            as<IRGeneric>(findOuterGeneric(header.differential)));
    }

    // Local storage for the intermediate context written by the primal call.
    auto intermediateVar = builder.emitVar(intermediateType);

    // NOTE(review): `primalFuncType` is built from `primalTypes` *before* the `out` context type
    // is appended two lines below, so the type handed to the primal inst does not list the
    // context parameter even though `primalArgs` contains the context variable. Confirm whether
    // this is intentional.
    auto primalFuncType = builder.getFuncType(primalTypes, primalResultType);
    primalArgs.add(intermediateVar);
    primalTypes.add(builder.getOutType(intermediateType));
    auto primalFunc =
        builder.emitBackwardDifferentiatePrimalInst(primalFuncType, specializedOriginalFunc);
    builder.emitCallInst(primalResultType, primalFunc, primalArgs);

    // The propagate call receives the context by value, as its last argument.
    propagateTypes.add(intermediateType);
    propagateArgs.add(builder.emitLoad(intermediateVar));
    auto propagateFuncType = builder.getFuncType(propagateTypes, builder.getVoidType());
    auto propagateFunc =
        builder.emitBackwardDifferentiatePropagateInst(propagateFuncType, specializedOriginalFunc);
    builder.emitCallInst(builder.getVoidType(), propagateFunc, propagateArgs);

    builder.emitReturn();

    addTranscribedFuncDecoration(builder, origFunc, cast<IRFunc>(header.differential));
    return header;
}

// Moves the parameters of `func` out of its first block into a new leading block of their own.
// The new block ends in an unconditional branch to the former first block.
void BackwardDiffTranscriberBase::makeParameterBlock(IRBuilder* inBuilder, IRFunc* func)
{
    IRBuilder builder = *inBuilder;

    // Both must be captured before the new block becomes the first block of `func`.
    auto oldEntryBlock = func->getFirstBlock();
    IRParam* firstParam = func->getFirstParam();

    builder.setInsertBefore(oldEntryBlock);

    // Note: emitBlock() does not seem to honor the current builder position, so the new block
    // is moved in front of the existing one by hand.
    auto paramBlock = builder.emitBlock();
    paramBlock->insertBefore(oldEntryBlock);
    builder.setInsertInto(paramBlock);

    // Move every parameter over. The successor is fetched first, because moving a parameter
    // detaches it from its old sibling list.
    for (IRParam *current = firstParam, *upcoming = nullptr; current; current = upcoming)
    {
        upcoming = current->getNextParam();
        current->insertAtEnd(paramBlock);
    }

    // Everything that referred to the old entry block now refers to the parameter block.
    oldEntryBlock->replaceUsesWith(paramBlock);

    // Terminate the parameter block. This has to come after the use replacement above, so that
    // the branch target itself is not rewritten.
    builder.emitBranch(oldEntryBlock);
}

// Runs the normalization passes that `func` has to go through before it can be
// reverse-differentiated: linkage-decoration removal, forced inlining, single-return conversion,
// continue-block and multi-level-break elimination, and CFG normalization.
//
// A function without any return is diagnosed. The function always returns SLANG_OK.
SlangResult BackwardDiffTranscriberBase::prepareFuncForBackwardDiff(IRFunc* func)
{
    removeLinkageDecorations(func);

    performPreAutoDiffForceInlining(func);

    DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext);
    diffTypeContext.setFunc(func);

    const auto returnCount = getReturnCount(func);
    if (returnCount == 0)
    {
        // An ill-formed function that never returns (e.g. it contains an infinite loop) cannot
        // possibly be reverse-differentiated, so report it here.
        getSink()->diagnose(func->sourceLoc, Diagnostics::functionNeverReturnsFatal, func);
    }
    else if (returnCount > 1)
    {
        convertFuncToSingleReturnForm(func->getModule(), func);
    }

    eliminateContinueBlocksInFunc(func->getModule(), func);

    eliminateMultiLevelBreakForFunc(func->getModule(), func);

    IRCFGNormalizationPass cfgPass = {this->getSink()};
    normalizeCFG(autoDiffSharedContext->moduleInst->getModule(), func, cfgPass);

    return SLANG_OK;
}

// Create a copy of originalFunc's forward derivative in the same generic context (if any) of
// `diffPropagateFunc`.
//
// The original function (together with its outer generic, if it has one) is cloned, prepared
// for reverse-mode differentiation, and forward-transcribed. The clone is removed afterwards,
// and the resulting forward derivative is migrated into the generic that encloses
// `diffPropagateFunc` when the derivative itself was generated inside a generic.
//
// Returns the new forward-derivative function. If preparing the clone fails,
// `diffPropagateFunc` is returned instead.
IRFunc* BackwardDiffTranscriberBase::generateNewForwardDerivativeForFunc(
    IRBuilder* builder,
    IRFunc* originalFunc,
    IRFunc* diffPropagateFunc)
{
    auto primalOuterParent = findOuterGeneric(originalFunc);
    if (!primalOuterParent)
        primalOuterParent = originalFunc;

    // Make a clone of original func so we won't modify the original.
    IRCloneEnv originalCloneEnv;
    primalOuterParent = cloneInst(&originalCloneEnv, builder, primalOuterParent);
    auto primalFunc = as<IRFunc>(getGenericReturnVal(primalOuterParent));

    // Strip any existing derivative decorations off the clone.
    stripDerivativeDecorations(primalFunc);
    eliminateDeadCode(primalOuterParent);

    // Perform required transformations and simplifications on the original func to make it
    // reversible.
    if (SLANG_FAILED(prepareFuncForBackwardDiff(primalFunc)))
        return diffPropagateFunc;

    // Forward transcribe the clone of the original func.
    ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>(
        autoDiffSharedContext->transcriberSet.forwardTranscriber);
    auto oldCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount();
    IRFunc* fwdDiffFunc =
        as<IRFunc>(getGenericReturnVal(fwdTranscriber.transcribe(builder, primalOuterParent)));

    // Validate the transcription result *before* it is dereferenced for the first time.
    SLANG_ASSERT(fwdDiffFunc);
    fwdDiffFunc->sourceLoc = primalFunc->sourceLoc;

    // Transcribe the bodies that the forward transcriber queued up while producing the header.
    // Every task added since `oldCount` is popped from the back of the queue and run right away.
    auto newCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount();
    for (auto i = oldCount; i < newCount; i++)
    {
        auto pendingTask = autoDiffSharedContext->followUpFunctionsToTranscribe.getLast();
        autoDiffSharedContext->followUpFunctionsToTranscribe.removeLast();
        SLANG_RELEASE_ASSERT(pendingTask.type == FuncBodyTranscriptionTaskType::Forward);
        fwdTranscriber.transcribeFunc(builder, pendingTask.originalFunc, pendingTask.resultFunc);
    }

    // Remove the clone of original func.
    primalOuterParent->removeAndDeallocate();

    // Remove redundant loads since they interfere with transposition logic.
    eliminateRedundantLoadStore(fwdDiffFunc);

    // Migrate the new forward derivative function into the generic parent of `diffPropagateFunc`.
    if (auto fwdParentGeneric = as<IRGeneric>(findOuterGeneric(fwdDiffFunc)))
    {
        // Clone forward derivative func from its own generic into current generic parent.
        GenericChildrenMigrationContext migrationContext;
        auto diffOuterGeneric = as<IRGeneric>(findOuterGeneric(diffPropagateFunc));
        SLANG_RELEASE_ASSERT(diffOuterGeneric);

        migrationContext.init(fwdParentGeneric, diffOuterGeneric, diffPropagateFunc);
        auto inst = fwdParentGeneric->getFirstBlock()->getFirstOrdinaryInst();
        builder->setInsertBefore(diffPropagateFunc);
        while (inst)
        {
            auto next = inst->getNextInst();
            auto cloned = migrationContext.cloneInst(builder, inst);
            if (inst == fwdDiffFunc)
            {
                // Everything up to and including the derivative has been cloned; continue with
                // the clone and stop.
                fwdDiffFunc = as<IRFunc>(cloned);
                break;
            }
            inst = next;
        }
        // The generic the derivative was generated in is no longer needed.
        fwdParentGeneric->removeAndDeallocate();
    }

    return fwdDiffFunc;
}

// Transcribes a parameter whose grandparent is a generic: the parameter is cloned as a primal
// inst and, if the clone is a parameter, re-attached to the block the builder is currently
// inserting into. No differential is produced, and `primalType` is not used.
InstPair BackwardDiffTranscriberBase::transcribeFuncParam(
    IRBuilder* builder,
    IRParam* origParam,
    IRInst* primalType)
{
    SLANG_UNUSED(primalType);

    SLANG_RELEASE_ASSERT(
        origParam->getParent() && origParam->getParent()->getParent() &&
        origParam->getParent()->getParent()->getOp() == kIROp_Generic);

    auto primalInst = maybeCloneForPrimalInst(builder, origParam);

    auto clonedParam = as<IRParam>(primalInst);
    if (clonedParam)
    {
        SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock());

        // Detach the clone from wherever it was emitted, then add it to the parameter list of
        // the current block.
        clonedParam->removeFromParent();
        builder->getInsertLoc().getBlock()->addParam(clonedParam);
    }

    return InstPair(primalInst, nullptr);
}

// Adds a [KeepAlive] decoration to every primal param replacement inst recorded in
// `paramInfo`, so that these insts survive DCE.
static void _lockPrimalParamReplacementInsts(
    IRBuilder* builder,
    ParameterBlockTransposeInfo& paramInfo)
{
    for (const auto& [_, value] : paramInfo.mapPrimalSpecificParamToReplacementInPropFunc)
    {
        builder->addKeepAliveDecoration(value);
    }
}

// Remove [KeepAlive] decorations for primal param replacement insts.
//
// A replacement inst that carries no [KeepAlive] decoration is skipped rather than
// dereferencing the null result of `findDecoration`.
static void _unlockPrimalParamReplacementInsts(ParameterBlockTransposeInfo& paramInfo)
{
    for (const auto& [_, value] : paramInfo.mapPrimalSpecificParamToReplacementInPropFunc)
    {
        if (auto keepAlive = value->findDecoration<IRKeepAliveDecoration>())
            keepAlive->removeAndDeallocate();
    }
}

// Transcribe a function definition.
//
// Fills the `diffPropagateFunc` shell with the backward-propagation body derived from
// `primalFunc`, and records the extracted primal function and the intermediate context type on
// `primalFunc` through decorations. Returns without doing anything if no forward derivative can be
// generated for `primalFunc`.
void BackwardDiffTranscriberBase::transcribeFuncImpl(
    IRBuilder* builder,
    IRFunc* primalFunc,
    IRFunc* diffPropagateFunc)
{
    SLANG_ASSERT(primalFunc);
    SLANG_ASSERT(diffPropagateFunc);
    // Reverse-mode transcription proceeds through the following stages (in order):
    //  - generate a temporary forward-derivative function and unzip its primal/diff insts,
    //  - move its blocks into `diffPropagateFunc` and transpose the parameter block,
    //  - transpose the differential blocks,
    //  - apply the checkpoint policy and extract the primal computation into its own function,
    //  - write back inout derivatives, drop primal-only params, and hoist results out of generics.
    // TODO(sai): Fill in more detailed documentation.

    // Generate a temporary forward derivative function as an intermediate step.
    // It is emitted right before the propagate func (or before its outer generic, if any).
    IRBuilder tempBuilder = *builder;
    if (auto outerGeneric = findOuterGeneric(diffPropagateFunc))
    {
        tempBuilder.setInsertBefore(outerGeneric);
    }
    else
    {
        tempBuilder.setInsertBefore(diffPropagateFunc);
    }

    auto fwdDiffFunc =
        generateNewForwardDerivativeForFunc(&tempBuilder, primalFunc, diffPropagateFunc);
    if (!fwdDiffFunc)
        return;

    // The result is differentiable iff the forward derivative returns a differential pair.
    bool isResultDifferentiable = as<IRDifferentialPairType>(fwdDiffFunc->getResultType());

    // Split first block into a parameter block.
    this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc));

    // This step adds a decoration to instructions that are computing the differential.
    // TODO: This is disabled for now because fwd-mode already adds differential decorations
    // wherever needed. We need to run this pass only for user-written forward derivative code.
    //
    // diffPropagationPass->propagateDiffInstDecoration(builder, fwdDiffFunc);

    diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
    IRFunc* unzippedFwdDiffFunc = fwdDiffFunc;

    // Move blocks from `unzippedFwdDiffFunc` to the `diffPropagateFunc` shell.
    // The blocks are collected first so that re-parenting does not disturb the iteration.
    builder->setInsertInto(diffPropagateFunc->getParent());
    {
        List<IRBlock*> workList;
        for (auto block = unzippedFwdDiffFunc->getFirstBlock(); block;
             block = block->getNextBlock())
            workList.add(block);

        for (auto block : workList)
            block->insertAtEnd(diffPropagateFunc);
    }

    // Transpose the first block (parameter block)
    auto paramTransposeInfo = splitAndTransposeParameterBlock(
        builder,
        diffPropagateFunc,
        primalFunc->sourceLoc,
        isResultDifferentiable);

    // The insts we inserted in paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc
    // may be used by write back logic that we are going to insert later.
    // Before then we want to keep them alive.
    _lockPrimalParamReplacementInsts(builder, paramTransposeInfo);

    builder->setInsertInto(diffPropagateFunc);

    // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter)
    // representing the derivative of the return value.
    DiffTransposePass::FuncTranspositionInfo transposeInfo = {paramTransposeInfo.dOutParam};
    diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, transposeInfo);

    // Apply checkpointing policy to legalize cross-scope uses of primal values
    // using either recompute or store strategies.
    auto primalsInfo = applyCheckpointPolicy(diffPropagateFunc);

    eliminateDeadCode(diffPropagateFunc);

    // Extracts the primal computations into its own func, turn all accesses to stored primal insts
    // into explicit intermediate data structure reads and writes.
    // `intermediateType` is filled in by `extractPrimalFunc`.
    IRInst* intermediateType = nullptr;
    auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(
        diffPropagateFunc,
        primalFunc,
        primalsInfo,
        paramTransposeInfo,
        intermediateType);

    // At this point the unzipped func is just an empty shell
    // and we can simply remove it.
    unzippedFwdDiffFunc->removeAndDeallocate();

    // Write back derivatives to inout parameters.
    writeBackDerivativeToInOutParams(paramTransposeInfo, diffPropagateFunc);

    // Remove primalFunc specific params.
    // Params are gathered into a list first because removal mutates the param list.
    List<IRInst*> paramsToRemove;
    for (auto param : diffPropagateFunc->getParams())
    {
        if (!paramTransposeInfo.propagateFuncParams.contains(param))
            paramsToRemove.add(param);
    }
    for (auto param : paramsToRemove)
    {
        if (param->hasUses())
        {
            // A primal-only param that is still referenced must have a registered replacement.
            IRInst* replacement = nullptr;
            paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc.tryGetValue(
                param,
                replacement);
            SLANG_RELEASE_ASSERT(replacement);
            param->replaceUsesWith(replacement);
        }
        param->removeAndDeallocate();
    }

    // The replacement insts no longer need the [KeepAlive] decorations added earlier.
    _unlockPrimalParamReplacementInsts(paramTransposeInfo);

    // If primal function is nested in a generic, we want to create separate generics for all the
    // associated things we have just created.
    auto primalOuterGeneric = findOuterGeneric(primalFunc);
    IRInst* specializedFunc = nullptr;
    auto intermediateTypeGeneric =
        hoistValueFromGeneric(*builder, intermediateType, specializedFunc, true);
    builder->setInsertBefore(primalFunc);
    builder->addBackwardDerivativeIntermediateTypeDecoration(primalFunc, intermediateTypeGeneric);

    auto primalFuncGeneric =
        hoistValueFromGeneric(*builder, extractedPrimalFunc, specializedFunc, true);
    builder->setInsertBefore(primalFunc);

    // Copy over checkpoint preference hints.
    {
        auto diffPrimalFunc = getResolvedInstForDecorations(primalFuncGeneric, true);
        auto checkpointHint = primalFunc->findDecoration<IRCheckpointHintDecoration>();
        if (checkpointHint)
            builder->addDecoration(diffPrimalFunc, checkpointHint->getOp());
    }

    if (auto existingDecor = primalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>())
    {
        // If we already created a header for primal func, move the body into the existing primal
        // func header.
        auto existingPrimalHeader = existingDecor->getBackwardDerivativePrimalFunc();
        // When the decoration refers to a `specialize`, the header is the specialized base.
        if (auto spec = as<IRSpecialize>(existingPrimalHeader))
            existingPrimalHeader = spec->getBase();
        moveInstChildren(existingPrimalHeader, primalFuncGeneric);
        primalFuncGeneric->replaceUsesWith(existingPrimalHeader);
        primalFuncGeneric->removeAndDeallocate();
        primalFuncGeneric = existingPrimalHeader;
    }
    else
    {
        // No header exists yet: decorate `primalFunc` with the (possibly specialized) primal func.
        auto specializedBackwardPrimalFunc =
            maybeSpecializeWithGeneric(*builder, primalFuncGeneric, primalOuterGeneric);
        builder->addBackwardDerivativePrimalDecoration(primalFunc, specializedBackwardPrimalFunc);
    }

    // Run local-variable initialization on both generated functions.
    initializeLocalVariables(
        builder->getModule(),
        as<IRGlobalValueWithCode>(getGenericReturnVal(primalFuncGeneric)));
    initializeLocalVariables(builder->getModule(), diffPropagateFunc);

    stripTempDecorations(diffPropagateFunc);

    sortBlocksInFunc(diffPropagateFunc);
    sortBlocksInFunc(primalFunc);
}

// Rewrite the parameter block of `diffFunc` (which at this point still has the forward-derivative
// parameter list) into the parameters needed by the future primal and propagate functions.
//
// builder                 IR builder; on return it is positioned before the parameter block's
//                         terminating branch.
// diffFunc                The function whose first block is the parameter block to transpose.
// primalLoc               Source location assigned to `diffFunc` and to the context parameter.
// isResultDifferentiable  When true, a `_s_dOut` parameter is appended for the derivative of the
//                         return value.
//
// Returns the bookkeeping describing which params belong to which function, the replacement
// insts for primal-only params, the pending inout write-backs and the dOut param (if any).
ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParameterBlock(
    IRBuilder* builder,
    IRFunc* diffFunc,
    SourceLoc primalLoc,
    bool isResultDifferentiable)
{
    // This method splits and transposes all the parameters for both the primal and propagate
    // computation. At the end of this method, the parameter block will contain a combination of
    // parameters for both the to-be-primal function and to-be-propagate function. We use
    // ParameterBlockTransposeInfo::primalFuncParams and
    // ParameterBlockTransposeInfo::propagateFuncParams to track which parameters are dedicated to
    // the future primal or propagate func. A later step will then split the parameters out to each
    // new function.

    ParameterBlockTransposeInfo result;

    // First, we initialize the IR builders and locate the important code insertion points that
    // will be used for the rest of this method.

    IRBlock* fwdDiffParameterBlock = diffFunc->getFirstBlock();

    // Find the 'next' block using the terminator inst of the parameter block.
    auto fwdParamBlockBranch = as<IRUnconditionalBranch>(fwdDiffParameterBlock->getTerminator());
    // The parameter block is expected to end in an unconditional branch; fail loudly instead of
    // dereferencing a null branch below.
    SLANG_RELEASE_ASSERT(fwdParamBlockBranch);
    // We create a new block after parameter block to hold insts that translates from transposed
    // parameters into something that the rest of the function can use.
    IRBuilder::insertBlockAlongEdge(diffFunc->getModule(), IREdge(&fwdParamBlockBranch->block));
    auto paramPreludeBlock = fwdParamBlockBranch->getTargetBlock();

    auto nextBlockBuilder = *builder;
    nextBlockBuilder.setInsertBefore(paramPreludeBlock->getFirstOrdinaryInst());

    // Locate the first differential block. Code that must go through transposition is inserted
    // at its start.
    IRBlock* firstDiffBlock = nullptr;
    for (auto block : diffFunc->getBlocks())
    {
        if (isDifferentialInst(block))
        {
            firstDiffBlock = block;
            break;
        }
    }

    SLANG_RELEASE_ASSERT(firstDiffBlock);

    // Independently locate the first `return` so that its source location can be attached to the
    // dOut parameter. This search must not be tied to the search above: the return may live in
    // (or after) the first differential block, and finding a return must not cut short the
    // search for the differential block.
    SourceLoc returnLoc;
    for (auto block : diffFunc->getBlocks())
    {
        auto terminator = block->getTerminator();
        if (as<IRReturn>(terminator))
        {
            returnLoc = terminator->sourceLoc;
            break;
        }
    }

    auto diffBuilder = *builder;
    diffBuilder.setInsertBefore(firstDiffBlock->getFirstOrdinaryInst());

    builder->setInsertBefore(fwdParamBlockBranch);

    // Collect all the original parameters.
    List<IRParam*> fwdParams;
    for (auto param : diffFunc->getParams())
        fwdParams.add(param);

    // Maintain a set for insts pending removal.
    OrderedHashSet<IRInst*> instsToRemove;

    // Now we begin the actual processing.
    // The first step is to transcribe all the existing parameters from the original function.
    // There are many cases to handle, including different combinations of parameter directions and
    // whether or not the parameter is differentiable.
    // To normalize the process for all these cases, we determine the following actions for each
    // parameter:
    // 1. Should this original parameter be translated to a parameter in the primal func and the
    // propagate func?
    //    if so, we emit a param inst representing the final parameter for that func. If the
    //    parameter should be mapped to both the primal func and the propagate func, we will emit
    //    two separate params with their final type.
    // 2. If this parameter has a corresponding primal func parameter, we replace all uses of the
    // original
    //    parameter in the primal computation code to the new primal parameter. If any
    //    initialization logic is needed to convert the type of the new primal parameter to what the
    //    code was expecting, we insert that code in the first block.
    // 3. If this parameter has a corresponding propagate func parameter, we replace all uses of the
    // original parameter
    //    in the diff computation code to the new propagate parameter. We insert necessary
    //    initialization diff block or the first block depending on whether we want that logic go
    //    through the transposition pass. We may need to replace the uses to different
    //    values/variables depending on whether that use is a read or write.
    // 4. If the parameter has both corresponding primal and propagate parameters, we also need to
    // consider
    //    how the future propagate function access the primal parameter. We will insert necessary
    //    preparation code that constructs temp vars or values to replace the primal parameter after
    //    we remove it from the propagate func.
    // Based on the above discussion, we need to compute the following values for each parameter:
    // - diffRefReplacement. What should all read(load) references to this parameter from
    // differential code be replaced to.
    // - diffWriteRefReplacement. What should all write references to this parameter from
    // differential code be replaced to.
    // - primalRefReplacement. What should all references to this parameter from primal code be
    // replaced to.
    // - mapPrimalSpecificParamToReplacementInPropFunc[param]. What should all references to this
    // parameter
    //      from the primal computation logic in the future propagate function be replaced to.
    for (auto fwdParam : fwdParams)
    {
        IRBuilderSourceLocRAII sourceLocationScope(builder, fwdParam->sourceLoc);

        // Define the replacement insts that we are going to fill in for each case.
        IRInst* diffRefReplacement = nullptr;
        IRInst* primalRefReplacement = nullptr;
        IRInst* diffWriteRefReplacement = nullptr;

        // Common logic that computes all the important types we care about.
        IRDifferentialPairType* diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType());
        auto inoutType = as<IRInOutType>(fwdParam->getDataType());
        auto outType = as<IROutType>(fwdParam->getDataType());
        if (inoutType)
            diffPairType = as<IRDifferentialPairType>(inoutType->getValueType());
        else if (outType)
            diffPairType = as<IRDifferentialPairType>(outType->getValueType());
        IRType* primalType = nullptr;
        IRType* diffType = nullptr;
        if (diffPairType)
        {
            primalType = diffPairType->getValueType();
            diffType = (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(
                builder,
                diffPairType);
        }

        // Now we handle each combination of parameter direction x differentiability.
        if (outType)
        {
            // Case 1: out parameters.
            // Out parameters need to be handled differently whether or not it is differentiable,
            // since the propagate function will not have a corresponding output.
            if (diffPairType)
            {
                // Create dOut param.
                auto diffParam = builder->emitParam(diffType);
                copyNameHintAndDebugDecorations(diffParam, fwdParam);
                result.propagateFuncParams.add(diffParam);
                primalRefReplacement = builder->emitParam(builder->getOutType(primalType));
                copyNameHintAndDebugDecorations(primalRefReplacement, fwdParam);

                // Create a local var for read access in pre-transpose code.
                // This will be the var from which we will fetch the final resulting derivative
                // after transposition.
                auto tempVar = nextBlockBuilder.emitVar(diffType);
                copyNameHintAndDebugDecorations(tempVar, fwdParam);
                result.propagateFuncSpecificPrimalInsts.add(tempVar);

                // Initialize the var with input diff param at start.
                // Note that we insert the store in the primal block so it won't get transposed.
                auto storeInst = nextBlockBuilder.emitStore(tempVar, diffParam);
                nextBlockBuilder.markInstAsDifferential(storeInst, primalType);
                // Since this store inst is specific to propagate function, we track it in a
                // set so we can remove it when we generate the primal func.
                result.propagateFuncSpecificPrimalInsts.add(storeInst);

                diffWriteRefReplacement = tempVar;
                diffRefReplacement = tempVar;
            }
            else
            {
                primalRefReplacement = builder->emitParam(outType);
                copyNameHintAndDebugDecorations(primalRefReplacement, fwdParam);
            }
            result.primalFuncParams.add(primalRefReplacement);

            // Create a local var for the out param for the primal part of the prop func.
            auto tempPrimalVar = nextBlockBuilder.emitVar(outType->getValueType());
            copyNameHintAndDebugDecorations(tempPrimalVar, fwdParam);
            result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] =
                tempPrimalVar;

            instsToRemove.add(fwdParam);
        }
        else if (!isRelevantDifferentialPair(fwdParam->getDataType()))
        {
            if (inoutType)
            {
                // Case 2: non differentiable inout parameter.
                // They should become an inout parameter in primal func, but an in parameter in
                // bwd func.
                fwdParam->removeFromParent();
                fwdDiffParameterBlock->addParam(fwdParam);
                result.primalFuncParams.add(fwdParam);

                primalRefReplacement = fwdParam;

                // Create an in param for the prop func.
                auto propParam = builder->emitParam(inoutType->getValueType());
                copyNameHintAndDebugDecorations(propParam, fwdParam);
                result.propagateFuncParams.add(propParam);

                // Create a local var for the out param for the primal part of the prop func.
                auto tempPrimalVar = nextBlockBuilder.emitVar(inoutType->getValueType());
                copyNameHintAndDebugDecorations(tempPrimalVar, fwdParam);

                result.propagateFuncSpecificPrimalInsts.add(tempPrimalVar);
                auto storeInst = nextBlockBuilder.emitStore(tempPrimalVar, propParam);
                result.propagateFuncSpecificPrimalInsts.add(storeInst);
                result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] =
                    tempPrimalVar;
            }
            else
            {
                // Case 3: non differentiable, non output parameters.
                // If parameter is not an out param and has nothing to do with differentiation,
                // simply move the parameter to the end.
                //
                fwdParam->removeFromParent();
                fwdDiffParameterBlock->addParam(fwdParam);
                result.primalFuncParams.add(fwdParam);
                result.propagateFuncParams.add(fwdParam);
                continue;
            }
        }
        else if (!inoutType)
        {
            // Case 4: `in` differentiable parameters.

            SLANG_RELEASE_ASSERT(diffPairType);

            // Create inout version.
            auto inoutDiffPairType = builder->getInOutType(diffPairType);
            primalRefReplacement = builder->emitParam(primalType);
            copyNameHintAndDebugDecorations(primalRefReplacement, fwdParam);

            result.primalFuncParams.add(primalRefReplacement);
            auto propParam = builder->emitParam(inoutDiffPairType);
            copyNameHintAndDebugDecorations(propParam, fwdParam);
            result.propagateFuncParams.add(propParam);

            // A reference to this parameter from the diff blocks should be replaced with a load
            // of the differential component of the pair.
            auto newParamLoad = diffBuilder.emitLoad(propParam);
            diffBuilder.markInstAsDifferential(newParamLoad, primalType);
            result.propagateFuncSpecificPrimalInsts.add(newParamLoad);

            diffRefReplacement =
                diffBuilder.emitDifferentialPairGetDifferential(diffType, newParamLoad);
            diffBuilder.markInstAsDifferential(diffRefReplacement, primalType);
            result.propagateFuncSpecificPrimalInsts.add(diffRefReplacement);

            // Load the primal component from the prop param and use it as replacement for the
            // primal param in the primal part of the prop func.
            // Since these are logic specific to propagate function, we will add them to the
            // `propagateFuncSpecificPrimalInsts` set so we can remove them when we generate the
            // primal func.
            auto primalReplacementLoad = nextBlockBuilder.emitLoad(propParam);
            result.propagateFuncSpecificPrimalInsts.add(primalReplacementLoad);
            auto primalVal = nextBlockBuilder.emitDifferentialPairGetPrimal(primalReplacementLoad);
            result.propagateFuncSpecificPrimalInsts.add(primalVal);
            result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = primalVal;

            instsToRemove.add(fwdParam);
        }
        else
        {
            // Case 5: `inout` differentiable parameters.
            SLANG_ASSERT(inoutType && diffPairType);

            // Process differentiable inout parameters.
            auto primalParam = builder->emitParam(builder->getInOutType(primalType));
            copyNameHintAndDebugDecorations(primalParam, fwdParam);
            result.primalFuncParams.add(primalParam);

            auto diffParam = builder->emitParam(inoutType);
            copyNameHintAndDebugDecorations(diffParam, fwdParam);
            result.propagateFuncParams.add(diffParam);

            // Primal references to this param is the new primal param.
            primalRefReplacement = primalParam;

            // Diff references to this param should be replaced with one local temp var
            // for read and one separate temp var for write.

            // Load the initial diff value.
            auto loadedParam = nextBlockBuilder.emitLoad(diffParam);
            result.propagateFuncSpecificPrimalInsts.add(loadedParam);

            auto initDiff =
                nextBlockBuilder.emitDifferentialPairGetDifferential(diffType, loadedParam);
            result.propagateFuncSpecificPrimalInsts.add(initDiff);

            // Create a local var for diff read access.
            auto diffVar = nextBlockBuilder.emitVar(diffType);
            copyNameHintAndDebugDecorations(diffVar, fwdParam);
            result.propagateFuncSpecificPrimalInsts.add(diffVar);
            diffRefReplacement = diffVar;

            // Clear the diff read var to zero at start of the function.
            auto dzero = getDifferentialZeroOfType(&nextBlockBuilder, primalType);
            result.propagateFuncSpecificPrimalInsts.add(dzero);
            auto initDiffStore = nextBlockBuilder.emitStore(diffVar, dzero);
            result.propagateFuncSpecificPrimalInsts.add(initDiffStore);

            // Create a local var for diff write access.
            auto diffWriteVar = nextBlockBuilder.emitVar(diffType);
            result.propagateFuncSpecificPrimalInsts.add(diffWriteVar);
            copyNameHintAndDebugDecorations(diffWriteVar, fwdParam);

            // Initialize the write var with the differential loaded from the incoming pair.
            auto writeStore = nextBlockBuilder.emitStore(diffWriteVar, initDiff);
            result.propagateFuncSpecificPrimalInsts.add(writeStore);

            diffWriteRefReplacement = diffWriteVar;

            // Create a local var for the primal logic in the propagate func.
            auto primalVar = nextBlockBuilder.emitVar(primalType);
            copyNameHintAndDebugDecorations(primalVar, fwdParam);

            result.propagateFuncSpecificPrimalInsts.add(primalVar);
            auto initPrimalVal = nextBlockBuilder.emitDifferentialPairGetPrimal(loadedParam);
            result.propagateFuncSpecificPrimalInsts.add(initPrimalVal);
            auto storeInst = nextBlockBuilder.emitStore(primalVar, initPrimalVal);
            result.propagateFuncSpecificPrimalInsts.add(storeInst);
            result.mapPrimalSpecificParamToReplacementInPropFunc[primalParam] = primalVar;
            // Record the (primal value, diff address) pair to be written back to `diffParam`.
            result.outDiffWritebacks[diffParam] = InstPair(initPrimalVal, diffVar);

            instsToRemove.add(fwdParam);
        }

        // We have emitted all the new parameters and computed the replacements for the original
        // parameter. Now we perform that replacement.
        // Uses are snapshotted first because the replacements below edit the use list.
        List<IRUse*> uses;
        for (auto use = fwdParam->firstUse; use; use = use->nextUse)
            uses.add(use);
        for (auto use : uses)
        {
            if (auto primalRef = as<IRPrimalParamRef>(use->getUser()))
            {
                SLANG_RELEASE_ASSERT(primalRefReplacement);
                primalRef->replaceUsesWith(primalRefReplacement);
                instsToRemove.add(primalRef);
            }
            else if (auto getPrimal = as<IRDifferentialPairGetPrimal>(use->getUser()))
            {
                SLANG_RELEASE_ASSERT(primalRefReplacement);
                getPrimal->replaceUsesWith(primalRefReplacement);
                instsToRemove.add(getPrimal);
            }
            else if (auto propagateRef = as<IRDiffParamRef>(use->getUser()))
            {
                SLANG_RELEASE_ASSERT(diffRefReplacement);
                auto refUse = propagateRef->firstUse;
                while (refUse)
                {
                    auto nextUse = refUse->nextUse;
                    // Is this use the dest operand of a store inst?
                    // If so, replace it with writeRefReplacement, otherwise, refReplacement.
                    if (refUse->getUser()->getOp() == kIROp_Store &&
                        refUse == refUse->getUser()->getOperands())
                    {
                        SLANG_RELEASE_ASSERT(diffWriteRefReplacement);
                        refUse->set(diffWriteRefReplacement);
                    }
                    else
                    {
                        refUse->set(diffRefReplacement);
                    }
                    refUse = nextUse;
                }
                instsToRemove.add(propagateRef);
            }
            else if (auto getDiff = as<IRDifferentialPairGetDifferential>(use->getUser()))
            {
                SLANG_RELEASE_ASSERT(diffRefReplacement);
                getDiff->replaceUsesWith(diffRefReplacement);
                instsToRemove.add(getDiff);
            }
            else
            {
                // If the user is something else, it'd better be a non relevant parameter.
                if (diffRefReplacement || diffWriteRefReplacement)
                    SLANG_UNEXPECTED("unknown use of parameter.");
                use->set(primalRefReplacement);
            }
        }
    }

    // Actually remove all the insts that we decided to remove in the process.
    for (auto inst : instsToRemove)
    {
        inst->removeAndDeallocate();
    }


    // The next step is to insert new parameters that are not related to any existing parameters.
    //
    // If the return type of the original function is differentiable,
    // add a parameter for 'derivative of the output' (d_out).
    // The type is the second last parameter type of the function.
    //
    auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount();
    IRParam* dOutParam = nullptr;
    if (isResultDifferentiable)
    {
        auto dOutParamType = as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 2);

        SLANG_ASSERT(dOutParamType);

        dOutParam = builder->emitParam(dOutParamType);
        dOutParam->sourceLoc = returnLoc;
        builder->addNameHintDecoration(dOutParam, UnownedStringSlice("_s_dOut"));
        result.propagateFuncParams.add(dOutParam);
    }

    // Add a parameter for intermediate val. Its type is the last parameter type of the function.
    auto ctxParam =
        builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1));
    builder->addNameHintDecoration(ctxParam, UnownedStringSlice("_s_diff_ctx"));
    builder->addDecoration(ctxParam, kIROp_PrimalContextDecoration);
    result.primalFuncParams.add(ctxParam);
    result.propagateFuncParams.add(ctxParam);
    result.dOutParam = dOutParam;

    diffFunc->sourceLoc = primalLoc;
    ctxParam->sourceLoc = primalLoc;

    return result;
}

// Emit, right before the return of `diffFunc`, one store per entry in `info.outDiffWritebacks`.
// Each store writes a freshly made differential pair (recorded primal value + the value loaded
// from the recorded differential address) into the destination parameter.
void BackwardDiffTranscriberBase::writeBackDerivativeToInOutParams(
    ParameterBlockTransposeInfo& info,
    IRFunc* diffFunc)
{
    // Visit every block and remember the return inst of the last block that contains one.
    IRInst* returnInst = nullptr;
    for (auto block : diffFunc->getBlocks())
    {
        for (auto child : block->getChildren())
        {
            if (child->getOp() != kIROp_Return)
                continue;
            returnInst = child;
            break;
        }
    }
    SLANG_RELEASE_ASSERT(returnInst);

    IRBuilder builder(autoDiffSharedContext->moduleInst);
    builder.setInsertBefore(returnInst);

    for (auto& writeback : info.outDiffWritebacks)
    {
        auto destParam = writeback.key;
        auto primalValue = writeback.value.primal;
        auto diffAddr = writeback.value.differential;

        // Read the accumulated differential, then rebuild the pair using the pointee type of
        // the destination.
        auto diffValue = builder.emitLoad(diffAddr);
        auto pairType = as<IRPtrTypeBase>(destParam->getFullType())->getValueType();
        auto pairValue = builder.emitMakeDifferentialPair(pairType, primalValue, diffValue);
        builder.emitStore(destParam, pairValue);
    }
}

// Transcribe an `IRSpecialize`.
//
// The primal half of the result is always a new specialize of the transcribed base with the
// transcribed arguments. The differential half is, in priority order:
//  1. a specialize of the already-mapped differential of the base (`instMapD`),
//  2. an existing derivative function registered directly on the specialize inst,
//  3. a specialize of the base of the derivative registered on the generic's inner value,
//  4. a specialize of the transcribed differential base, when the inner value is backward
//     differentiable or is a function type,
//  5. null otherwise.
InstPair BackwardDiffTranscriberBase::transcribeSpecialize(
    IRBuilder* builder,
    IRSpecialize* origSpecialize)
{
    auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase());

    List<IRInst*> primalArgs;
    const UInt origArgCount = origSpecialize->getArgCount();
    for (UInt argIndex = 0; argIndex < origArgCount; argIndex++)
    {
        auto primalArg = findOrTranscribePrimalInst(builder, origSpecialize->getArg(argIndex));
        primalArgs.add(primalArg);
    }

    auto primalType = findOrTranscribePrimalInst(builder, origSpecialize->getFullType());
    auto primalSpecialize = (IRSpecialize*)builder->emitSpecializeInst(
        (IRType*)primalType,
        primalBase,
        primalArgs.getCount(),
        primalArgs.getBuffer());

    // Snapshot the arguments of the primal specialize for reuse in a differential specialize.
    auto gatherPrimalSpecializeArgs = [&](List<IRInst*>& outArgs)
    {
        for (UInt argIndex = 0; argIndex < primalSpecialize->getArgCount(); argIndex++)
            outArgs.add(primalSpecialize->getArg(argIndex));
    };

    // Specialize `diffBaseInst` with `specArgs`, typed with the type kind.
    auto emitDiffSpecialize = [&](IRInst* diffBaseInst, List<IRInst*>& specArgs)
    {
        return builder->emitSpecializeInst(
            builder->getTypeKind(),
            diffBaseInst,
            specArgs.getCount(),
            specArgs.getBuffer());
    };

    if (auto diffBase = instMapD.tryGetValue(origSpecialize->getBase()))
    {
        List<IRInst*> args;
        gatherPrimalSpecializeArgs(args);
        auto diffSpecialize = emitDiffSpecialize(*diffBase, args);
        return InstPair(primalSpecialize, diffSpecialize);
    }

    auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase()));

    // Look for an IRBackwardDerivativeDecoration on the specialize inst.
    // (Normally, this would be on the inner IRFunc, but in this case only the JVP func
    // can be specialized, so we put a decoration on the IRSpecialize)
    //
    if (auto derivativeFunc = findExistingDiffFunc(origSpecialize))
    {
        // Make sure this isn't itself a specialize .
        SLANG_RELEASE_ASSERT(!as<IRSpecialize>(derivativeFunc));

        return InstPair(primalSpecialize, derivativeFunc);
    }

    if (auto diffBase = findExistingDiffFunc(genericInnerVal))
    {
        List<IRInst*> args;
        gatherPrimalSpecializeArgs(args);

        // A `BackwardDerivative` decoration on an inner func of a generic should always be a
        // `specialize`.
        auto diffBaseSpecialize = as<IRSpecialize>(diffBase);
        SLANG_RELEASE_ASSERT(diffBaseSpecialize);

        // Note: this assumes that the generic arguments to specialize the derivative is the same as
        // the generic args to specialize the primal function. This is true for all of our core
        // module functions, but we may need to rely on more general substitution logic here.
        auto diffSpecialize = emitDiffSpecialize(diffBaseSpecialize->getBase(), args);

        return InstPair(primalSpecialize, diffSpecialize);
    }

    if (isBackwardDifferentiableFunc(genericInnerVal) || as<IRFuncType>(genericInnerVal))
    {
        List<IRInst*> args;
        gatherPrimalSpecializeArgs(args);

        auto diffCallee = findOrTranscribeDiffInst(builder, origSpecialize->getBase());
        auto diffSpecialize = emitDiffSpecialize(diffCallee, args);
        return InstPair(primalSpecialize, diffSpecialize);
    }

    // Nothing differentiable to pair with the primal specialize.
    return InstPair(primalSpecialize, nullptr);
}
} // namespace Slang
