Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Accounts/AssemblyLoading/ConditionalAssemblyProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public static void Initialize(string rootPath, IConditionalAssemblyContext conte
CreateAssembly("net462", "System.Diagnostics.DiagnosticSource", "8.0.0.1").WithWindowsPowerShell(),
CreateAssembly("net462", "System.Text.Encodings.Web", "8.0.0.0").WithWindowsPowerShell(),
CreateAssembly("net47", "System.Security.Cryptography.Cng", "4.3.0.0").WithWindowsPowerShell(),
CreateAssembly("netstandard2.0", "Azure.Core", "1.47.3.0"),
CreateAssembly("netstandard2.0", "Azure.Core", "1.50.0.0"),
CreateAssembly("netstandard2.0", "Azure.Identity.Broker", "1.1.0.0"),
CreateAssembly("netstandard2.0", "Azure.Identity", "1.13.0.0"),
CreateAssembly("netstandard2.0", "Microsoft.Bcl.AsyncInterfaces", "8.0.0.0"),
Expand All @@ -61,7 +61,7 @@ public static void Initialize(string rootPath, IConditionalAssemblyContext conte
CreateAssembly("netstandard2.0", "Microsoft.Identity.Client.NativeInterop", "0.16.2.0"),
CreateAssembly("netstandard2.0", "Microsoft.IdentityModel.Abstractions", "6.35.0.0"),
CreateAssembly("netstandard2.0", "System.Buffers", "4.0.3.0").WithWindowsPowerShell(),
CreateAssembly("netstandard2.0", "System.ClientModel", "1.6.1.0"),
CreateAssembly("netstandard2.0", "System.ClientModel", "1.8.0.0"),
CreateAssembly("netstandard2.0", "System.Memory.Data", "8.0.0.1"),
CreateAssembly("netstandard2.0", "System.Memory", "4.0.1.2").WithWindowsPowerShell(),
CreateAssembly("netstandard2.0", "System.Net.Http.WinHttpHandler", "4.0.4.0").WithWindowsPowerShell(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public void InitCommand()
{
command = new StorageCloudBlobCmdletBase(BlobMock)
{
Context = new AzureStorageContext(CloudStorageAccount.DevelopmentStorageAccount),
Context = new AzureStorageContext(CloudStorageAccount.DevelopmentStorageAccount, null, null, null),
CommandRuntime = MockCmdRunTime
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ public void GetStorageAccountByConnectionStringAndSasToken()
string endpoint = "http://storageaccountname.blob.core.windows.net";
string connectionString = String.Format("BlobEndpoint={0};QueueEndpoint={0};TableEndpoint={0};SharedAccessSignature={1}", endpoint, sasToken);
CloudStorageAccount account = command.GetStorageAccountByConnectionString(connectionString);
AzureStorageContext context = new AzureStorageContext(account);
AzureStorageContext context = new AzureStorageContext(account, null, null, null);
connectionString = String.Format("BlobEndpoint={0};SharedAccessSignature={1}", endpoint, sasToken);
account = command.GetStorageAccountByConnectionString(connectionString);
context = new AzureStorageContext(account);
context = new AzureStorageContext(account, null, null, null);
}
}
}
6 changes: 3 additions & 3 deletions src/Storage/Storage.Test/Common/StorageCloudCmdletBaseTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public void CleanCommand()
public void GetCloudStorageAccountFromContextTest()
{
CloudStorageAccount account = CloudStorageAccount.DevelopmentStorageAccount;
command.Context = new AzureStorageContext(account);
command.Context = new AzureStorageContext(account, null, null, null);
Assert.AreEqual(command.Context, command.GetCmdletStorageContext());
}

Expand All @@ -73,7 +73,7 @@ public void WriteObjectWithStorageContextWithNullContextTest()
public void WriteObjectWithStorageContextWithContextTest()
{
CloudStorageAccount account = CloudStorageAccount.DevelopmentStorageAccount;
command.Context = new AzureStorageContext(account);
command.Context = new AzureStorageContext(account, null, null, null);

AzureStorageBase item = new AzureStorageBase();
command.WriteObjectWithStorageContext(item);
Expand Down Expand Up @@ -105,7 +105,7 @@ public void WriteObjectWithStorageContextWihtEnumerableList()
public void ShouldInitServiceChannelTest()
{
CloudStorageAccount account = CloudStorageAccount.DevelopmentStorageAccount;
command.Context = new AzureStorageContext(account);
command.Context = new AzureStorageContext(account, null, null, null);
string toss;
Assert.IsFalse(command.TryGetStorageAccount(command.SMProfile, out toss));
}
Expand Down
5 changes: 5 additions & 0 deletions src/Storage/Storage.Test/Service/MockStorageBlobManagement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,11 @@ public BlobServiceClient GetBlobServiceClient(BlobClientOptions options = null)
throw new NotImplementedException();
}

public bool IsSasWithOAuthCredential()
{
throw new NotImplementedException();
}

/// <summary>
/// The storage context
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ public static IStorageContext GetStorageContext(this IStorageService service)
{
return new AzureStorageContext(new CloudStorageAccount(new StorageCredentials(service.Name, service.AuthenticationKeys.First()),
new StorageUri(service.BlobEndpoint), new StorageUri(service.QueueEndpoint),
new StorageUri(service.TableEndpoint), new StorageUri(service.FileEndpoint)));
new StorageUri(service.TableEndpoint), new StorageUri(service.FileEndpoint)),
null,
false);
}


Expand Down
18 changes: 16 additions & 2 deletions src/Storage/Storage.common/Common/AzureStorageContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,20 @@ public string ConnectionString {
/// <param name="accountName">Storage account name</param>
/// <param name="DefaultContext"></param>
/// <param name="logWriter"></param>
public AzureStorageContext(CloudStorageAccount account, string accountName = null, IAzureContext DefaultContext = null, DebugLogWriter logWriter = null)
public AzureStorageContext(CloudStorageAccount account, string accountName = null, IAzureContext DefaultContext = null, DebugLogWriter logWriter = null) :
this(account, accountName, false, null, null)
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: The constructor chaining is incorrect. The original constructor accepts DefaultContext and logWriter parameters but passes null, null instead of DefaultContext, logWriter to the new constructor. This should be: this(account, accountName, false, DefaultContext, logWriter) to properly pass through these parameters.

Suggested change
this(account, accountName, false, null, null)
this(account, accountName, false, DefaultContext, logWriter)

Copilot uses AI. Check for mistakes.
{
}

/// <summary>
/// Create a storage context usign cloud storage account
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in comment: usign should be using.

Copilot uses AI. Check for mistakes.
/// </summary>
/// <param name="account">cloud storage account</param>
/// <param name="accountName">Storage account name</param>
/// <param name="isOAuthToken"></param>
/// <param name="DefaultContext"></param>
/// <param name="logWriter"></param>
public AzureStorageContext(CloudStorageAccount account, string accountName = null, bool isOAuthToken = false, IAzureContext DefaultContext = null, DebugLogWriter logWriter = null)
{
StorageAccount = account;
TableStorageAccount = XTable.CloudStorageAccount.Parse(StorageAccount.ToString(true));
Expand Down Expand Up @@ -195,7 +208,8 @@ public AzureStorageContext(CloudStorageAccount account, string accountName = nul
StorageAccountName = "[Anonymous]";
}
}
if (account.Credentials != null && account.Credentials.IsToken)
if ((account.Credentials != null && account.Credentials.IsToken)
|| isOAuthToken)
{
Track2OauthToken = new AzureSessionCredential(DefaultContext, logWriter);
}
Expand Down
6 changes: 5 additions & 1 deletion src/Storage/Storage.common/Storage.common.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Azure.Storage.Files.Shares" Version="12.23.0" />
<PackageReference Include="Azure.Storage.Files.Shares" Version="12.25.0-beta.1" />
<PackageReference Include="Microsoft.Azure.Cosmos.Table" Version="1.0.8" />
<PackageReference Include="Microsoft.Azure.Storage.Blob" Version="11.2.3" />
</ItemGroup>

<ItemGroup>
<PackageReference Update="Azure.Core" Version="1.50.0" />
</ItemGroup>
<Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory).., build.proj))\src\Az.Post.props" />

</Project>
60 changes: 47 additions & 13 deletions src/Storage/Storage/Blob/Cmdlet/CopyAzureStorageBlob.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,22 @@

namespace Microsoft.WindowsAzure.Commands.Storage.Blob.Cmdlet
{
using System;
using System.Collections.Generic;
using System.Management.Automation;
using System.Security.Permissions;
using System.Threading.Tasks;
using Azure.Commands.Common.Authentication.Abstractions;
using Commands.Common.Storage.ResourceModel;
using global::Azure;
using global::Azure.Core;
using global::Azure.Storage.Blobs;
using global::Azure.Storage.Blobs.Models;
using global::Azure.Storage.Blobs.Specialized;
using Microsoft.Azure.Commands.ResourceManager.Common.ArgumentCompleters;
using Microsoft.Azure.Storage.Blob;
using Microsoft.WindowsAzure.Commands.Storage.Common;
using Microsoft.WindowsAzure.Commands.Storage.Model.Contract;
using System;
using System.Collections.Generic;
using System.Management.Automation;
using System.Security.Permissions;
using System.Threading.Tasks;
using Track2Models = global::Azure.Storage.Blobs.Models;

[Cmdlet("Copy", Azure.Commands.ResourceManager.Common.AzureRMConstants.AzurePrefix + "StorageBlob", SupportsShouldProcess = true, DefaultParameterSetName = ContainerNameParameterSet),OutputType(typeof(AzureStorageBlob))]
Expand Down Expand Up @@ -323,7 +325,7 @@ private void CopyBlobSync(IStorageBlobManagement destChannel, BlobBaseClient src
destCloudBlob = Util.GetTrack2BlobClientWithType(destCloudBlob, destChannel.StorageContext, srcBlobType, ClientOptions);
}

Func<long, Task> taskGenerator = (taskId) => CopyFromUri(taskId, destChannel, srcCloudBlob.GenerateUriWithCredentials(Channel.StorageContext), destCloudBlob);
Func<long, Task> taskGenerator = (taskId) => CopyFromUri(taskId, destChannel, srcCloudBlob.GenerateUriWithCredentials(Channel.StorageContext), Channel, destCloudBlob);
RunTask(taskGenerator);
}

Expand All @@ -332,15 +334,24 @@ private void CopyBlobSync(IStorageBlobManagement destChannel, string srcUri, str
Track2Models.BlobType srcBlobType = Util.GetBlobType(new BlobBaseClient(new Uri(srcUri), ClientOptions), true).Value;

BlobBaseClient destBlob = this.GetDestBlob(destChannel, destContainer, destBlobName, srcBlobType);
Func<long, Task> taskGenerator = (taskId) => CopyFromUri(taskId, destChannel, new Uri(srcUri), destBlob);
Func<long, Task> taskGenerator = (taskId) => CopyFromUri(taskId, destChannel, new Uri(srcUri), Channel, destBlob);
RunTask(taskGenerator);
}

private async Task CopyFromUri(long taskId, IStorageBlobManagement destChannel, Uri srcUri, BlobBaseClient destBlob)
private async Task CopyFromUri(long taskId, IStorageBlobManagement destChannel, Uri srcUri, IStorageBlobManagement sourceChannel, BlobBaseClient destBlob)
{
bool destExist = true;
Track2Models.BlobType? srcBlobType = Util.GetBlobType(new BlobBaseClient(srcUri, ClientOptions), true).Value;
Track2Models.BlobType? destBlobType = Util.GetBlobType(new BlobBaseClient(srcUri, ClientOptions), true).Value;
BlobBaseClient srcBlobClient;
if (sourceChannel.StorageContext != null && sourceChannel.StorageContext.Track2OauthToken != null)
{
srcBlobClient = new BlobBaseClient(srcUri, sourceChannel.StorageContext.Track2OauthToken, ClientOptions);
}
else
{
srcBlobClient = new BlobBaseClient(srcUri, ClientOptions);
}
Track2Models.BlobType? srcBlobType = Util.GetBlobType(srcBlobClient, true).Value;
Track2Models.BlobType? destBlobType = srcBlobType;
Track2Models.BlobProperties properties = null;

try
Expand Down Expand Up @@ -398,8 +409,6 @@ private async Task CopyFromUri(long taskId, IStorageBlobManagement destChannel,

if (!destExist || this.ConfirmOverwrite(srcUri.AbsoluteUri.ToString(), destBlob.Uri.ToString()))
{

BlobBaseClient srcBlobClient= new BlobBaseClient(srcUri, ClientOptions);
Track2Models.BlobProperties srcProperties = srcBlobClient.GetProperties(cancellationToken: this.CmdletCancellationToken).Value;

Track2Models.BlobHttpHeaders httpHeaders = new Track2Models.BlobHttpHeaders
Expand Down Expand Up @@ -453,6 +462,12 @@ private async Task CopyFromUri(long taskId, IStorageBlobManagement destChannel,
destPageBlob.Create(srcProperties.ContentLength, pageBlobCreateOptions, this.CmdletCancellationToken);

Track2Models.PageBlobUploadPagesFromUriOptions pageBlobUploadPagesFromUriOptions = new Track2Models.PageBlobUploadPagesFromUriOptions();
if (sourceChannel.StorageContext != null && sourceChannel.StorageContext.Track2OauthToken != null)
{
string oauthToken = sourceChannel.StorageContext.Track2OauthToken.GetToken(null, this.CmdletCancellationToken).TokenValue;
pageBlobUploadPagesFromUriOptions.SourceAuthentication = new HttpAuthorization("Bearer", oauthToken);
}

long pageCopyOffset = 0;
progressHandler.Report(pageCopyOffset);
long contentLenLeft = srcProperties.ContentLength;
Expand Down Expand Up @@ -488,6 +503,12 @@ private async Task CopyFromUri(long taskId, IStorageBlobManagement destChannel,
SourceRange = new global::Azure.HttpRange(appendCopyOffset, appendContentSize)
};

if (sourceChannel.StorageContext != null && sourceChannel.StorageContext.Track2OauthToken != null)
{
string oauthToken = sourceChannel.StorageContext.Track2OauthToken.GetToken(null, this.CmdletCancellationToken).TokenValue;
appendBlobAppendBlockFromUriOptions.SourceAuthentication = new HttpAuthorization("Bearer", oauthToken);
}

destAppendBlob.AppendBlockFromUri(srcUri, appendBlobAppendBlockFromUriOptions, this.CmdletCancellationToken);
appendCopyOffset += appendContentSize;
progressHandler.Report(appendContentSize);
Expand All @@ -513,6 +534,12 @@ private async Task CopyFromUri(long taskId, IStorageBlobManagement destChannel,
options.Metadata = srcProperties.Metadata;
options.Tags = blobTags ?? null;

if (sourceChannel.StorageContext != null && sourceChannel.StorageContext.Track2OauthToken != null)
{
string oauthToken = sourceChannel.StorageContext.Track2OauthToken.GetToken(new TokenRequestContext(), this.CmdletCancellationToken).Token;
Copy link

Copilot AI Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistency in GetToken usage: Line 539 uses GetToken(new TokenRequestContext(), ...).Token while lines 467, 508, and 581 use GetToken(null, ...).TokenValue. For consistency, this should likely use the same pattern as the other calls: GetToken(null, this.CmdletCancellationToken).TokenValue.

Suggested change
string oauthToken = sourceChannel.StorageContext.Track2OauthToken.GetToken(new TokenRequestContext(), this.CmdletCancellationToken).Token;
string oauthToken = sourceChannel.StorageContext.Track2OauthToken.GetToken(null, this.CmdletCancellationToken).TokenValue;

Copilot uses AI. Check for mistakes.
options.SourceAuthentication = new HttpAuthorization("Bearer", oauthToken);
}

destBlobClient.SyncCopyFromUri(srcUri, options, this.CmdletCancellationToken);

// Set rehydrate priority
Expand Down Expand Up @@ -542,12 +569,19 @@ private async Task CopyFromUri(long taskId, IStorageBlobManagement destChannel,
progressHandler.Report(copyoffset);
foreach (string id in blockIDs)
{
Track2Models.StageBlockFromUriOptions stageBlockOptions = new Track2Models.StageBlockFromUriOptions();
long blocksize = blockLength;
if (copyoffset + blocksize > srcProperties.ContentLength)
{
blocksize = srcProperties.ContentLength - copyoffset;
}
destBlockBlob.StageBlockFromUri(srcUri, id, new global::Azure.HttpRange(copyoffset, blocksize), null, null, null, cancellationToken: this.CmdletCancellationToken);
stageBlockOptions.SourceRange = new global::Azure.HttpRange(copyoffset, blocksize);
if (sourceChannel.StorageContext != null && sourceChannel.StorageContext.Track2OauthToken != null)
{
string oauthToken = sourceChannel.StorageContext.Track2OauthToken.GetToken(null, this.CmdletCancellationToken).TokenValue;
stageBlockOptions.SourceAuthentication = new HttpAuthorization("Bearer", oauthToken);
}
destBlockBlob.StageBlockFromUri(srcUri, id, stageBlockOptions, cancellationToken: this.CmdletCancellationToken);
copyoffset += blocksize;
progressHandler.Report(copyoffset);

Expand Down
17 changes: 15 additions & 2 deletions src/Storage/Storage/Blob/Cmdlet/GetAzureStorageBlobContent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ internal void GetBlobContent(CloudBlobContainer container, string blobName, stri

ValidatePipelineCloudBlobContainer(container);

if (UseTrack2Sdk())
if (UseTrack2Sdk() || IsSasTokenWithOAuth(container))
{
BlobContainerClient track2container = AzureStorageContainer.GetTrack2BlobContainerClient(container, Channel.StorageContext, ClientOptions);
BlobBaseClient blobClient = track2container.GetBlobBaseClient(blobName);
Expand All @@ -278,6 +278,19 @@ internal void GetBlobContent(CloudBlobContainer container, string blobName, stri
}
}

private bool IsSasTokenWithOAuth(CloudBlobContainer container)
{
if (container.ServiceClient.Credentials.IsSAS) //SAS
{
if (Channel.StorageContext.Track2OauthToken != null)
{
return true;
}
}

return false;
}

/// <summary>
/// get blob content
/// </summary>
Expand Down Expand Up @@ -543,7 +556,7 @@ public override void ExecuteCmdlet()
case BlobParameterSet:
if (ShouldProcess(CloudBlob.Name, "Download"))
{
if (!(CloudBlob is InvalidCloudBlob) && !UseTrack2Sdk())
if (!(CloudBlob is InvalidCloudBlob) && !UseTrack2Sdk() && !IsSasTokenWithOAuth(CloudBlob.Container))
{
GetBlobContent(CloudBlob, FileName, true);
}
Expand Down
Loading
Loading